mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 12:34:13 +08:00)

Compare commits: attn-refac...controlnet (120 commits)
@@ -137,7 +136,6 @@ class PatchedLoraProjection(nn.Module):
         self.w_down = None

     def forward(self, input):
         # print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}")
         if self.lora_scale is None:
             self.lora_scale = 1.0
         if self.lora_linear_layer is None:

@@ -1008,19 +1007,12 @@ class LoraLoaderMixin:
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        controlnet=False,
         **kwargs,
     ):
         r"""
         Return state dict for lora weights and the network alphas.

-        <Tip warning={true}>
-
-        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
-        This function is experimental and might change in the future.
-
-        </Tip>
-
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                 Can be either:

@@ -1032,6 +1024,8 @@ class LoraLoaderMixin:
                 - A [torch state
                   dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

+            controlnet (`bool`, *optional*, defaults to False):
+                If we're converting a ControlNet LoRA checkpoint.
             cache_dir (`Union[str, os.PathLike]`, *optional*):
                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                 is not used.

@@ -1143,20 +1137,21 @@ class LoraLoaderMixin:
             state_dict = pretrained_model_name_or_path_or_dict

         network_alphas = None
-        if all(
-            (
-                k.startswith("lora_te_")
-                or k.startswith("lora_unet_")
-                or k.startswith("lora_te1_")
-                or k.startswith("lora_te2_")
-            )
-            for k in state_dict.keys()
-        ):
-            # Map SDXL blocks correctly.
-            if unet_config is not None:
-                # use unet config to remap block numbers
-                state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
-            state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)
+        if not controlnet:
+            if all(
+                (
+                    k.startswith("lora_te_")
+                    or k.startswith("lora_unet_")
+                    or k.startswith("lora_te1_")
+                    or k.startswith("lora_te2_")
+                )
+                for k in state_dict.keys()
+            ):
+                # Map SDXL blocks correctly.
+                if unet_config is not None:
+                    # use unet config to remap block numbers
+                    state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
+                state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)

         return state_dict, network_alphas
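For context on the `startswith` checks above: a Kohya/SGM-style LoRA state dict uses flat, underscore-separated key names. The snippet below is only an illustration of the kind of keys that check is meant to recognize; the example key names are hypothetical stand-ins, not taken from a specific checkpoint.

# Illustrative sketch only: example keys in the Kohya-style layout that the
# `all(k.startswith(...))` test above is designed to detect. Real checkpoints
# contain many such entries plus matching "lora_up.weight" and "alpha" tensors.
example_keys = [
    "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_k.lora_down.weight",
    "lora_te_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight",
]
is_kohya_style = all(
    k.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_")) for k in example_keys
)
print(is_kohya_style)  # True -> such a dict would be routed through _convert_kohya_lora_to_diffusers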
@@ -1700,7 +1695,6 @@ class LoraLoaderMixin:
                 diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
             else:
                 diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")

             if "middle.block" in diffusers_name:
                 diffusers_name = diffusers_name.replace("middle.block", "mid_block")
             else:

@@ -1835,6 +1829,7 @@ class LoraLoaderMixin:
             te_state_dict.update(te2_state_dict)

         new_state_dict = {**unet_state_dict, **te_state_dict}

         return new_state_dict, network_alphas

     def unload_lora_weights(self):

@@ -2517,3 +2512,105 @@ class FromOriginalControlnetMixin:
         controlnet.to(torch_dtype=torch_dtype)

         return controlnet
+
+
+class ControlLoRAMixin(LoraLoaderMixin):
+    # Simplify ControlNet LoRA loading.
+    def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs):
+        from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
+        from .pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint
+
+        state_dict, _ = self.lora_state_dict(pretrained_model_name_or_path_or_dict, controlnet=True, **kwargs)
+        controlnet_config = kwargs.pop("controlnet_config", None)
+        if controlnet_config is None:
+            raise ValueError("Must provide a `controlnet_config`.")
+
+        # ControlNet LoRA has a mix of things. Some parameters correspond to LoRA and some correspond
+        # to the ones belonging to the original state_dict (initialized from the underlying UNet).
+        # So, we first map the LoRA parameters and then we load the remaining state_dict into
+        # the ControlNet.
+        converted_state_dict = convert_ldm_unet_checkpoint(
+            state_dict, controlnet=True, config=controlnet_config, skip_extract_state_dict=True, controlnet_lora=True
+        )
+
+        # Load whatever is matching.
+        load_state_dict_results = self.load_state_dict(converted_state_dict, strict=False)
+        if not all("lora" in k for k in load_state_dict_results.unexpected_keys):
+            raise ValueError(
+                f"The unexpected keys must only belong to LoRA parameters at this point, but found the following keys that are non-LoRA\n: {load_state_dict_results.unexpected_keys}"
+            )
+
+        # Filter out the rest of the state_dict for handling LoRA.
+        remaining_state_dict = {
+            k: v for k, v in converted_state_dict.items() if k in load_state_dict_results.unexpected_keys
+        }
+
+        # Handle LoRA.
+        lora_grouped_dict = defaultdict(dict)
+        lora_layers_list = []
+
+        all_keys = list(remaining_state_dict.keys())
+        for key in all_keys:
+            value = remaining_state_dict.pop(key)
+            attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+            lora_grouped_dict[attn_processor_key][sub_key] = value
+
+        if len(remaining_state_dict) > 0:
+            raise ValueError(
+                f"The `remaining_state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
+            )
+
+        for key, value_dict in lora_grouped_dict.items():
+            attn_processor = self
+            for sub_key in key.split("."):
+                attn_processor = getattr(attn_processor, sub_key)
+
+            # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
+            # or add_{k,v,q,out_proj}_proj_lora layers.
+            rank = value_dict["lora.down.weight"].shape[0]
+
+            if isinstance(attn_processor, LoRACompatibleConv):
+                in_features = attn_processor.in_channels
+                out_features = attn_processor.out_channels
+                kernel_size = attn_processor.kernel_size
+
+                lora = LoRAConv2dLayer(
+                    in_features=in_features,
+                    out_features=out_features,
+                    rank=rank,
+                    kernel_size=kernel_size,
+                    stride=attn_processor.stride,
+                    padding=attn_processor.padding,
+                    # initial_weight=attn_processor.weight,
+                    # initial_bias=attn_processor.bias,
+                )
+            elif isinstance(attn_processor, LoRACompatibleLinear):
+                lora = LoRALinearLayer(
+                    attn_processor.in_features,
+                    attn_processor.out_features,
+                    rank,
+                    # initial_weight=attn_processor.weight,
+                    # initial_bias=attn_processor.bias,
+                )
+            else:
+                raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
+
+            value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
+            load_state_dict_results = lora.load_state_dict(value_dict, strict=False)
+            if not all("initial" in k for k in load_state_dict_results.unexpected_keys):
+                raise ValueError("Incorrect `value_dict` for the LoRA layer.")
+            lora_layers_list.append((attn_processor, lora))
+
+        # set correct dtype & device
+        lora_layers_list = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in lora_layers_list]
+
+        # set lora layers
+        for target_module, lora_layer in lora_layers_list:
+            target_module.set_lora_layer(lora_layer)
+
+    def unload_lora_weights(self):
+        for _, module in self.named_modules():
+            if hasattr(module, "set_lora_layer"):
+                module.set_lora_layer(None)
+
+    # Implement `fuse_lora()` and `unfuse_lora()` (sayakpaul).
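Because `ControlNetModel` picks up `ControlLoRAMixin` as a base class later in this diff, the new `load_lora_weights` can be called directly on a ControlNet instance. Below is a minimal usage sketch, not taken from the PR: the checkpoint filename and repo id are hypothetical, and passing the model's own diffusers-format config as `controlnet_config` is an assumption about what `convert_ldm_unet_checkpoint` expects through its `config` argument.

# Hedged usage sketch (hypothetical identifiers; see the assumptions noted above).
import torch
from diffusers import ControlNetModel

controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float32
)

# Attach the Control-LoRA parameters (the LoRA layers plus the non-LoRA tensors
# shipped in the checkpoint) to the ControlNet.
controlnet.load_lora_weights(
    "control-lora-canny-rank128.safetensors",  # hypothetical local checkpoint
    controlnet_config=dict(controlnet.config),
)

# Detach every LoRA layer again.
controlnet.unload_lora_weights()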
@@ -19,7 +19,8 @@ from torch import nn
 from torch.nn import functional as F

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import FromOriginalControlnetMixin
+from ..loaders import ControlLoRAMixin, FromOriginalControlnetMixin, UNet2DConditionLoadersMixin
+from ..models.lora import LoRACompatibleConv
 from ..utils import BaseOutput, logging
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,

@@ -80,7 +81,7 @@ class ControlNetConditioningEmbedding(nn.Module):
     ):
         super().__init__()

-        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+        self.conv_in = LoRACompatibleConv(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

         self.blocks = nn.ModuleList([])

@@ -96,6 +97,7 @@ class ControlNetConditioningEmbedding(nn.Module):

     def forward(self, conditioning):
         embedding = self.conv_in(conditioning)
+        print(f"From conv_in embedding of ControlNet: {embedding[0, :5, :5, -1]}")
         embedding = F.silu(embedding)

         for block in self.blocks:

@@ -107,7 +109,9 @@ class ControlNetConditioningEmbedding(nn.Module):
         return embedding


-class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
+class ControlNetModel(
+    ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, FromOriginalControlnetMixin, ControlLoRAMixin
+):
     """
     A ControlNet model.

@@ -247,7 +251,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         # input
         conv_in_kernel = 3
         conv_in_padding = (conv_in_kernel - 1) // 2
-        self.conv_in = nn.Conv2d(
+        self.conv_in = LoRACompatibleConv(
             in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
         )

@@ -719,6 +723,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         timesteps = timesteps.expand(sample.shape[0])

         t_emb = self.time_proj(timesteps)
+        print(f"t_emb: {t_emb[0, :3]}")

         # timesteps does not contain any weights and will always return f32 tensors
         # but time_embedding might actually be running in fp16. so we need to cast here.

@@ -726,6 +731,8 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
         t_emb = t_emb.to(dtype=sample.dtype)

         emb = self.time_embedding(t_emb, timestep_cond)
+        print(f"emb: {emb[0, :3]}")

         aug_emb = None

         if self.class_embedding is not None:

@@ -764,6 +771,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):

         # 2. pre-process
         sample = self.conv_in(sample)
+        print(f"From ControlNet conv_in: {sample[0, :5, :5, -1]}")

         controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
         sample = sample + controlnet_cond

@@ -18,6 +18,7 @@ import numpy as np
 import torch
 from torch import nn

+from ..models.lora import LoRACompatibleLinear
 from .activations import get_activation


@@ -166,10 +167,10 @@ class TimestepEmbedding(nn.Module):
     ):
         super().__init__()

-        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
+        self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)

         if cond_proj_dim is not None:
-            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
+            self.cond_proj = LoRACompatibleLinear(cond_proj_dim, in_channels, bias=False)
         else:
             self.cond_proj = None

@@ -179,7 +180,7 @@ class TimestepEmbedding(nn.Module):
             time_embed_dim_out = out_dim
         else:
             time_embed_dim_out = time_embed_dim
-        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
+        self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)

         if post_act_fn is None:
             self.post_act = None

@@ -40,7 +40,17 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):


 class LoRALinearLayer(nn.Module):
-    def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        rank=4,
+        network_alpha=None,
+        device=None,
+        dtype=None,
+        # initial_weight=None,
+        # initial_bias=None,
+    ):
         super().__init__()

         self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)

@@ -52,6 +62,10 @@ class LoRALinearLayer(nn.Module):
         self.out_features = out_features
         self.in_features = in_features

+        # # Control-LoRA specific.
+        # self.initial_weight = initial_weight
+        # self.initial_bias = initial_bias

         nn.init.normal_(self.down.weight, std=1 / rank)
         nn.init.zeros_(self.up.weight)

@@ -66,11 +80,32 @@ class LoRALinearLayer(nn.Module):
             up_hidden_states *= self.network_alpha / self.rank

         return up_hidden_states.to(orig_dtype)
+        # else:
+        #     initial_weight = self.initial_weight
+        #     if initial_weight.device != hidden_states.device:
+        #         initial_weight = initial_weight.to(hidden_states.device)
+        #     return torch.nn.functional.linear(
+        #         hidden_states.to(dtype),
+        #         initial_weight
+        #         + (torch.mm(self.up.weight.data.flatten(start_dim=1), self.down.weight.data.flatten(start_dim=1)))
+        #         .reshape(self.initial_weight.shape)
+        #         .type(orig_dtype),
+        #         self.initial_bias,
+        #     )


 class LoRAConv2dLayer(nn.Module):
     def __init__(
-        self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
+        self,
+        in_features,
+        out_features,
+        rank=4,
+        kernel_size=(1, 1),
+        stride=(1, 1),
+        padding=0,
+        network_alpha=None,
+        # initial_weight=None,
+        # initial_bias=None,
     ):
         super().__init__()

@@ -84,6 +119,13 @@ class LoRAConv2dLayer(nn.Module):
         self.network_alpha = network_alpha
         self.rank = rank

+        # # Control-LoRA specific.
+        # self.initial_weight = initial_weight
+        # self.initial_bias = initial_bias
+        # self.stride = stride
+        # self.kernel_size = kernel_size
+        # self.padding = padding

         nn.init.normal_(self.down.weight, std=1 / rank)
         nn.init.zeros_(self.up.weight)

@@ -98,6 +140,20 @@ class LoRAConv2dLayer(nn.Module):
             up_hidden_states *= self.network_alpha / self.rank

         return up_hidden_states.to(orig_dtype)
+        # else:
+        #     initial_weight = self.initial_weight
+        #     if initial_weight.device != hidden_states.device:
+        #         initial_weight = initial_weight.to(hidden_states.device)
+        #     return torch.nn.functional.conv2d(
+        #         hidden_states,
+        #         initial_weight
+        #         + (torch.mm(self.up.weight.flatten(start_dim=1), self.down.weight.flatten(start_dim=1)))
+        #         .reshape(self.initial_weight.shape)
+        #         .type(orig_dtype),
+        #         self.initial_bias,
+        #         self.stride,
+        #         self.padding,
+        #     )


 class LoRACompatibleConv(nn.Conv2d):
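For reference, the LoRA layers above compute a low-rank update: the `down` projection maps to `rank` channels, the `up` projection maps back, and the result is scaled by `network_alpha / rank` before the LoRA-compatible wrapper adds it to the frozen base layer's output. The commented-out branches hint at an alternative that folds the update into the base weight instead. A small self-contained sketch of both views, written against plain `torch.nn` with made-up dimensions:

# Hedged illustration only (toy dimensions, plain torch.nn, not the diffusers classes).
import torch
import torch.nn as nn

in_features, out_features, rank, network_alpha = 16, 32, 4, 8.0

base = nn.Linear(in_features, out_features)      # frozen base layer
down = nn.Linear(in_features, rank, bias=False)  # LoRA "down" projection
up = nn.Linear(rank, out_features, bias=False)   # LoRA "up" projection
nn.init.normal_(down.weight, std=1 / rank)
nn.init.normal_(up.weight, std=0.02)  # the real layers zero-init `up`; randomized here so the check is non-trivial

x = torch.randn(2, in_features)
scale = network_alpha / rank

# 1) Additive form: base output plus the scaled LoRA branch.
y_additive = base(x) + scale * up(down(x))

# 2) Fused form hinted at by the commented-out code: W' = W + scale * (up @ down).
fused_weight = base.weight + scale * (up.weight @ down.weight)
y_fused = nn.functional.linear(x, fused_weight, base.bias)

print(torch.allclose(y_additive, y_fused, atol=1e-6))  # True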
@@ -104,7 +104,10 @@ EXAMPLE_DOC_STRING = """


 class StableDiffusionXLControlNetPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+    DiffusionPipeline,
+    TextualInversionLoaderMixin,
+    LoraLoaderMixin,
+    FromSingleFileMixin,
 ):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.

@@ -1067,6 +1070,7 @@ class StableDiffusionXLControlNetPipeline(
         target_size = target_size or (height, width)

         add_text_embeds = pooled_prompt_embeds
+        print(f"pooled_prompt_embeds: {pooled_prompt_embeds.shape}")
         add_time_ids = self._get_add_time_ids(
             original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
         )

@@ -377,11 +377,21 @@ def create_ldm_bert_config(original_config):


 def convert_ldm_unet_checkpoint(
-    checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
+    checkpoint,
+    config,
+    path=None,
+    extract_ema=False,
+    controlnet=False,
+    skip_extract_state_dict=False,
+    controlnet_lora=False,
 ):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """
+    if not controlnet and controlnet_lora:
+        raise ValueError(f"`controlnet_lora` cannot be done with `controlnet` set to {controlnet}.")
+    if controlnet and controlnet_lora:
+        skip_extract_state_dict = True

     if skip_extract_state_dict:
         unet_state_dict = checkpoint

@@ -419,10 +429,22 @@ def convert_ldm_unet_checkpoint(

     new_checkpoint = {}

-    new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
-    new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
-    new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
-    new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+    if controlnet_lora:
+        # Safe to pop as it doesn't have anything.
+        _ = unet_state_dict.pop("lora_controlnet")
+
+    if not controlnet_lora:
+        new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+        new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+        new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+        new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+    else:
+        new_checkpoint["time_embedding.linear_1.lora_down.weight"] = unet_state_dict["time_embed.0.down"]
+        new_checkpoint["time_embedding.linear_1.lora_up.weight"] = unet_state_dict["time_embed.0.up"]
+        new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+        new_checkpoint["time_embedding.linear_2.lora_down.weight"] = unet_state_dict["time_embed.2.down"]
+        new_checkpoint["time_embedding.linear_2.lora_up.weight"] = unet_state_dict["time_embed.2.up"]
+        new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]

     if config["class_embed_type"] is None:
         # No parameters to port

@@ -436,13 +458,26 @@ def convert_ldm_unet_checkpoint(
         raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

     if config["addition_embed_type"] == "text_time":
-        new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
-        new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
-        new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
-        new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+        if not controlnet_lora:
+            new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
+            new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+            new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
+            new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
+        else:
+            new_checkpoint["add_embedding.linear_1.lora_down.weight"] = unet_state_dict["label_emb.0.0.down"]
+            new_checkpoint["add_embedding.linear_1.lora_up.weight"] = unet_state_dict["label_emb.0.0.up"]
+            new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
+            new_checkpoint["add_embedding.linear_2.lora_down.weight"] = unet_state_dict["label_emb.0.2.down"]
+            new_checkpoint["add_embedding.linear_2.lora_up.weight"] = unet_state_dict["label_emb.0.2.up"]
+            new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]

-    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
-    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+    if not controlnet_lora:
+        new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+        new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+    else:
+        new_checkpoint["conv_in.lora_down.weight"] = unet_state_dict["input_blocks.0.0.down"]
+        new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+        new_checkpoint["conv_in.lora_up.weight"] = unet_state_dict["input_blocks.0.0.up"]

     if not controlnet:
         new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]

@@ -588,8 +623,9 @@ def convert_ldm_unet_checkpoint(
         orig_index += 2

         diffusers_index = 0
+        diffusers_index_limit = 6

-        while diffusers_index < 6:
+        while diffusers_index < diffusers_index_limit:
             new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
                 f"input_hint_block.{orig_index}.weight"
             )

@@ -599,12 +635,13 @@ def convert_ldm_unet_checkpoint(
             diffusers_index += 1
             orig_index += 2

-        new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
-            f"input_hint_block.{orig_index}.weight"
-        )
-        new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
-            f"input_hint_block.{orig_index}.bias"
-        )
+        if not controlnet_lora:
+            new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
+                f"input_hint_block.{orig_index}.weight"
+            )
+            new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
+                f"input_hint_block.{orig_index}.bias"
+            )

         # down blocks
         for i in range(num_input_blocks):

@@ -615,6 +652,21 @@ def convert_ldm_unet_checkpoint(
         new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
         new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")

+    if controlnet_lora:
+        modified_new_checkpoint = {}
+        down_pattern = r"\.down$"
+        up_pattern = r"\.up$"
+
+        for key in new_checkpoint:
+            new_key = key
+            new_key = re.sub(down_pattern, ".lora.down.weight", new_key)
+            new_key = re.sub(up_pattern, ".lora.up.weight", new_key)
+            new_key = new_key.replace("lora_down", "lora.down")
+            new_key = new_key.replace("lora_up", "lora.up")
+            modified_new_checkpoint[new_key] = new_checkpoint[key]
+
+        new_checkpoint = modified_new_checkpoint
+
     return new_checkpoint
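To make the final renaming pass concrete, here is a standalone illustration of how a few representative keys are rewritten when `controlnet_lora=True`; the example keys only demonstrate the pattern and are not copied from a real checkpoint.

# Illustrative sketch of the key renaming applied under `controlnet_lora=True`.
import re

down_pattern = r"\.down$"
up_pattern = r"\.up$"

example_keys = [
    "controlnet_cond_embedding.conv_in.down",    # generic ".down" suffix
    "time_embedding.linear_1.lora_down.weight",  # "lora_down" style from the time-embedding branch
]
for key in example_keys:
    new_key = re.sub(down_pattern, ".lora.down.weight", key)
    new_key = re.sub(up_pattern, ".lora.up.weight", new_key)
    new_key = new_key.replace("lora_down", "lora.down").replace("lora_up", "lora.up")
    print(f"{key} -> {new_key}")
# controlnet_cond_embedding.conv_in.down -> controlnet_cond_embedding.conv_in.lora.down.weight
# time_embedding.linear_1.lora_down.weight -> time_embedding.linear_1.lora.down.weight

The resulting dotted `*.lora.down.weight` / `*.lora.up.weight` form is what `ControlLoRAMixin.load_lora_weights` groups on when it splits each key into a module path and its last three components.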