Compare commits

...

120 Commits

Author SHA1 Message Date
sayakpaul
6dc4d694c4 debug 2023-10-10 09:29:01 +02:00
sayakpaul
ca6895a114 debug 2023-10-09 22:07:41 +02:00
sayakpaul
b08a0a61ce debug 2023-10-09 22:03:53 +02:00
sayakpaul
26662de868 debug 2023-10-09 21:58:17 +02:00
sayakpaul
332cbfd303 debug 2023-10-09 21:56:33 +02:00
sayakpaul
5871ecc980 remove dtype of t from commit trail. 2023-10-09 17:13:29 +02:00
sayakpaul
bf7afc2f78 remove dtype of t from commit trail. 2023-10-09 17:11:08 +02:00
sayakpaul
c4ad76e16c have t printed. 2023-10-09 17:00:44 +02:00
sayakpaul
ef430bfae9 step by step debug 2023-10-09 16:52:55 +02:00
sayakpaul
4087dbfbb6 step by step debug 2023-10-09 15:36:27 +02:00
Sayak Paul
86f5980ce8 change class name 2023-09-28 14:28:51 +05:30
Sayak Paul
c6a04063cc remove print 2023-09-28 13:14:18 +05:30
Sayak Paul
567a2dee1a log 2023-09-28 12:31:52 +05:30
Sayak Paul
5ceb0a2f08 log 2023-09-28 12:01:49 +05:30
Sayak Paul
b42169482c another 2023-09-28 11:55:19 +05:30
Sayak Paul
13e8c87777 better conditioning 2023-09-28 11:19:18 +05:30
Sayak Paul
64284b1742 make strict loading false 2023-09-28 11:14:59 +05:30
Sayak Paul
a054d80ceb better support? 2023-09-28 11:11:19 +05:30
sayakpaul
8dcc44ba31 debugging 2023-09-19 09:08:24 +01:00
sayakpaul
57d52b4e8e debugging 2023-09-19 09:08:04 +01:00
sayakpaul
9cfce5f19e debugging 2023-09-18 23:13:35 +01:00
sayakpaul
e1286db6d2 debugging 2023-09-18 23:11:33 +01:00
sayakpaul
05b7f8b2ba debugging 2023-09-18 22:55:49 +01:00
sayakpaul
87ee3728bc debugging 2023-09-18 22:49:02 +01:00
sayakpaul
b1099e8b51 minor clean up 2023-09-18 12:38:56 +01:00
sayakpaul
432fa6b65d debugging 2023-09-18 11:58:45 +01:00
sayakpaul
70c0c68428 debugging 2023-09-18 11:57:05 +01:00
sayakpaul
9699382311 debugging 2023-09-18 11:55:12 +01:00
sayakpaul
a66a46847a debugging 2023-09-18 11:36:23 +01:00
sayakpaul
f17befc1a0 fix: doc 2023-09-18 11:17:27 +01:00
Sayak Paul
dd0ce66cc4 make style 2023-09-05 15:04:00 +05:30
Sayak Paul
367e6c0b25 remove prints. 2023-09-05 14:45:54 +05:30
Sayak Paul
ebec2119cf fix: embeddings. 2023-09-05 13:25:17 +05:30
Sayak Paul
b35f61fac3 fix: embeddings. 2023-09-05 13:23:42 +05:30
Sayak Paul
f7fde8a68d fix: embeddings. 2023-09-05 13:19:59 +05:30
Sayak Paul
2027143f81 sanity 2023-09-05 13:17:09 +05:30
Sayak Paul
610be144b0 sanity 2023-09-05 13:15:09 +05:30
Sayak Paul
d901a9a04a sanity 2023-09-05 13:10:31 +05:30
Sayak Paul
8ad9b977f3 better state_dict munging 2023-09-05 13:01:35 +05:30
Sayak Paul
1bfbefba32 better state_dict munging 2023-09-05 13:00:57 +05:30
Sayak Paul
71f3c91ac2 better state_dict munging 2023-09-05 12:59:32 +05:30
Sayak Paul
33cfc2d64d debugging 2023-09-05 12:54:47 +05:30
Sayak Paul
8206ef02a2 debugging 2023-09-05 12:52:24 +05:30
Sayak Paul
e238f3a7a6 debugging 2023-09-05 12:48:14 +05:30
Sayak Paul
aa4f65f066 debugging 2023-09-05 12:47:07 +05:30
Sayak Paul
fa4782f3ec debugging 2023-09-05 12:45:49 +05:30
Sayak Paul
8f6608d670 debugging 2023-09-05 12:42:04 +05:30
Sayak Paul
11ddd6cecf debugging 2023-09-05 12:34:43 +05:30
Sayak Paul
d0e1cfb5d4 debugging 2023-09-05 12:30:27 +05:30
Sayak Paul
b3b7798a30 debugging 2023-09-05 12:26:48 +05:30
Sayak Paul
d16673242e empty lora controlnet key 2023-09-05 12:17:26 +05:30
Sayak Paul
11a85cdf25 empty lora controlnet key 2023-09-05 12:15:47 +05:30
Sayak Paul
5e5004da0d fix: exception raise/. 2023-09-05 12:10:54 +05:30
Sayak Paul
260bc7527e better modularity 2023-09-05 12:06:27 +05:30
Sayak Paul
d88c806a5d better simplicity. 2023-09-05 11:46:52 +05:30
Sayak Paul
95f09d8fb8 remove unneeded stuff. 2023-09-05 11:24:46 +05:30
Sayak Paul
fbb2d7bf49 Merge branch 'main' into controlnet-sai 2023-09-05 11:17:14 +05:30
Sayak Paul
2baae10d26 remove unnecessary stuff from loaders.py 2023-09-05 11:16:37 +05:30
Sayak Paul
e143979ad3 changes 2023-09-05 11:11:25 +05:30
Sayak Paul
5bdb7bb25d changes 2023-09-05 10:31:54 +05:30
Sayak Paul
0e42a2c850 changes 2023-09-05 10:27:02 +05:30
Sayak Paul
e103f776c2 changes 2023-09-05 10:25:02 +05:30
Sayak Paul
c35161dc9b changes 2023-09-05 10:19:19 +05:30
Sayak Paul
d326f24fd5 changes 2023-09-05 10:06:42 +05:30
Sayak Paul
101ceebe5a changes 2023-09-05 10:01:15 +05:30
Sayak Paul
000f74cedb changes 2023-09-05 09:55:46 +05:30
Sayak Paul
f9eb243c74 changes 2023-09-05 09:53:06 +05:30
Sayak Paul
7c26e9037b changes 2023-09-05 09:45:22 +05:30
Sayak Paul
9d43c953cc changes 2023-09-05 09:11:56 +05:30
Sayak Paul
e871eeefd0 changes 2023-09-05 09:04:21 +05:30
Sayak Paul
efec092b4d changes 2023-09-05 09:01:51 +05:30
Sayak Paul
e2e547722c changes 2023-09-05 08:59:54 +05:30
Sayak Paul
dc27a087dc changes 2023-09-05 08:56:42 +05:30
Sayak Paul
c13e824570 changes 2023-09-05 08:51:03 +05:30
Sayak Paul
182e4552a7 changes 2023-09-05 08:48:54 +05:30
Sayak Paul
4c93de5db0 changes 2023-09-05 08:46:59 +05:30
Sayak Paul
7e87bf935b changes 2023-09-05 08:45:01 +05:30
Sayak Paul
6b6195fa8a debugging 2023-09-05 08:12:38 +05:30
Sayak Paul
13dffc3892 debugging 2023-09-05 08:00:20 +05:30
Sayak Paul
40480deb60 more stuff 2023-08-24 07:43:36 +05:30
Sayak Paul
48257fb218 fix 2023-08-22 17:25:44 +05:30
Sayak Paul
50f3f4a799 make method a part of it now 2023-08-22 17:20:00 +05:30
Sayak Paul
4436870fd9 remove print 2023-08-22 17:07:06 +05:30
Sayak Paul
e047c4e9bd better state dict munging 2023-08-22 17:05:24 +05:30
Sayak Paul
58c9f985ae debugging 2023-08-22 17:01:46 +05:30
Sayak Paul
ae1a178b73 debugging 2023-08-22 16:59:28 +05:30
Sayak Paul
6295db5e17 debugging 2023-08-22 16:53:55 +05:30
Sayak Paul
a58abee3d5 debugging 2023-08-22 16:49:13 +05:30
Sayak Paul
12d7b5dfd9 debugging 2023-08-22 16:44:31 +05:30
Sayak Paul
00fea8a0e7 debugging 2023-08-22 16:42:12 +05:30
Sayak Paul
3924166bed debugging 2023-08-22 16:38:02 +05:30
Sayak Paul
c3e0dd830d debugging 2023-08-22 16:33:27 +05:30
Sayak Paul
e572736547 debugging 2023-08-22 16:27:16 +05:30
Sayak Paul
58604783b1 debugging 2023-08-22 16:22:38 +05:30
Sayak Paul
3ad63ea168 debugging 2023-08-22 16:17:04 +05:30
Sayak Paul
260d5cc619 debugging 2023-08-22 16:09:53 +05:30
Sayak Paul
8d19befc03 debugging 2023-08-22 16:08:30 +05:30
Sayak Paul
09003fb60c debugging 2023-08-22 16:02:58 +05:30
Sayak Paul
24a2551f66 debugging 2023-08-22 16:00:19 +05:30
Sayak Paul
6adc8d55d5 successful LoRA state dict parsing. 2023-08-22 15:49:51 +05:30
Sayak Paul
54d1508c5a successful LoRA state dict parsing. 2023-08-22 15:41:59 +05:30
Sayak Paul
e47b47dab6 debugging 2023-08-22 15:39:41 +05:30
Sayak Paul
04f663d664 debugging 2023-08-22 15:34:54 +05:30
Sayak Paul
dde7ed6431 debugging 2023-08-22 15:32:16 +05:30
Sayak Paul
df3dfe3668 debugging 2023-08-22 15:30:42 +05:30
Sayak Paul
4baa7e3945 debugging 2023-08-22 15:17:26 +05:30
Sayak Paul
a9dfd86311 debugging 2023-08-22 14:42:20 +05:30
Sayak Paul
86515e4491 seeing. 2023-08-22 13:52:46 +05:30
Sayak Paul
070983480f simplify condition. 2023-08-22 13:47:50 +05:30
Sayak Paul
c8ec943cba remove unnecessary statements. 2023-08-22 13:44:10 +05:30
Sayak Paul
38fb6fe37b debugging 2023-08-22 13:38:42 +05:30
Sayak Paul
2257ba9dd3 debugging 2023-08-22 13:28:21 +05:30
Sayak Paul
6f9e14bcfc debugging 2023-08-22 13:25:10 +05:30
Sayak Paul
30dee21a34 let's see 2023-08-22 13:20:14 +05:30
Sayak Paul
e736960821 sai controlnet 2023-08-22 11:33:43 +05:30
Sayak Paul
49327162c9 exploring 2023-08-22 11:29:35 +05:30
Sayak Paul
2d4ae0026d relax check. 2023-08-22 11:25:09 +05:30
Sayak Paul
e9fe443cca wondering' 2023-08-18 17:53:01 +05:30
Sayak Paul
9a78f038fa wondering' 2023-08-18 17:48:24 +05:30
Sayak Paul
c7a369afd3 make controlnet sublcass from a loraloader 2023-08-18 16:55:16 +05:30
6 changed files with 270 additions and 52 deletions

View File

@@ -137,7 +137,6 @@ class PatchedLoraProjection(nn.Module):
self.w_down = None
def forward(self, input):
# print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}")
if self.lora_scale is None:
self.lora_scale = 1.0
if self.lora_linear_layer is None:
@@ -1008,19 +1007,12 @@ class LoraLoaderMixin:
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
controlnet=False,
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
@@ -1032,6 +1024,8 @@ class LoraLoaderMixin:
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
controlnet (`bool`, *optional*, defaults to False):
Whether the checkpoint being loaded is a ControlNet LoRA checkpoint.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
@@ -1143,20 +1137,21 @@ class LoraLoaderMixin:
state_dict = pretrained_model_name_or_path_or_dict
network_alphas = None
if all(
(
k.startswith("lora_te_")
or k.startswith("lora_unet_")
or k.startswith("lora_te1_")
or k.startswith("lora_te2_")
)
for k in state_dict.keys()
):
# Map SDXL blocks correctly.
if unet_config is not None:
# use unet config to remap block numbers
state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)
if not controlnet:
if all(
(
k.startswith("lora_te_")
or k.startswith("lora_unet_")
or k.startswith("lora_te1_")
or k.startswith("lora_te2_")
)
for k in state_dict.keys()
):
# Map SDXL blocks correctly.
if unet_config is not None:
# use unet config to remap block numbers
state_dict = cls._maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict)
return state_dict, network_alphas
@@ -1700,7 +1695,6 @@ class LoraLoaderMixin:
diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
else:
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
if "middle.block" in diffusers_name:
diffusers_name = diffusers_name.replace("middle.block", "mid_block")
else:
@@ -1835,6 +1829,7 @@ class LoraLoaderMixin:
te_state_dict.update(te2_state_dict)
new_state_dict = {**unet_state_dict, **te_state_dict}
return new_state_dict, network_alphas
def unload_lora_weights(self):
@@ -2517,3 +2512,105 @@ class FromOriginalControlnetMixin:
controlnet.to(torch_dtype=torch_dtype)
return controlnet
class ControlLoRAMixin(LoraLoaderMixin):
# Simplify ControlNet LoRA loading.
def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs):
from .models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
from .pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint
state_dict, _ = self.lora_state_dict(pretrained_model_name_or_path_or_dict, controlnet=True, **kwargs)
controlnet_config = kwargs.pop("controlnet_config", None)
if controlnet_config is None:
raise ValueError("Must provide a `controlnet_config`.")
# A ControlNet LoRA checkpoint mixes two kinds of parameters: LoRA parameters and parameters
# belonging to the original state_dict (initialized from the underlying UNet).
# So, we first map everything to diffusers-style keys, load whatever matches the ControlNet
# directly, and then handle the remaining LoRA parameters separately.
converted_state_dict = convert_ldm_unet_checkpoint(
state_dict, controlnet=True, config=controlnet_config, skip_extract_state_dict=True, controlnet_lora=True
)
# Load whatever is matching.
load_state_dict_results = self.load_state_dict(converted_state_dict, strict=False)
if not all("lora" in k for k in load_state_dict_results.unexpected_keys):
raise ValueError(
f"The unexpected keys must only belong to LoRA parameters at this point, but found the following keys that are non-LoRA\n: {load_state_dict_results.unexpected_keys}"
)
# Filter out the rest of the state_dict for handling LoRA.
remaining_state_dict = {
k: v for k, v in converted_state_dict.items() if k in load_state_dict_results.unexpected_keys
}
# Handle LoRA.
lora_grouped_dict = defaultdict(dict)
lora_layers_list = []
all_keys = list(remaining_state_dict.keys())
for key in all_keys:
value = remaining_state_dict.pop(key)
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
lora_grouped_dict[attn_processor_key][sub_key] = value
if len(remaining_state_dict) > 0:
raise ValueError(
f"The `remaining_state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
)
for key, value_dict in lora_grouped_dict.items():
attn_processor = self
for sub_key in key.split("."):
attn_processor = getattr(attn_processor, sub_key)
# Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
# or add_{k,v,q,out_proj}_proj_lora layers.
rank = value_dict["lora.down.weight"].shape[0]
if isinstance(attn_processor, LoRACompatibleConv):
in_features = attn_processor.in_channels
out_features = attn_processor.out_channels
kernel_size = attn_processor.kernel_size
lora = LoRAConv2dLayer(
in_features=in_features,
out_features=out_features,
rank=rank,
kernel_size=kernel_size,
stride=attn_processor.stride,
padding=attn_processor.padding,
# initial_weight=attn_processor.weight,
# initial_bias=attn_processor.bias,
)
elif isinstance(attn_processor, LoRACompatibleLinear):
lora = LoRALinearLayer(
attn_processor.in_features,
attn_processor.out_features,
rank,
# initial_weight=attn_processor.weight,
# initial_bias=attn_processor.bias,
)
else:
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
load_state_dict_results = lora.load_state_dict(value_dict, strict=False)
if not all("initial" in k for k in load_state_dict_results.unexpected_keys):
raise ValueError("Incorrect `value_dict` for the LoRA layer.")
lora_layers_list.append((attn_processor, lora))
# set correct dtype & device
lora_layers_list = [(t, l.to(device=self.device, dtype=self.dtype)) for t, l in lora_layers_list]
# set lora layers
for target_module, lora_layer in lora_layers_list:
target_module.set_lora_layer(lora_layer)
def unload_lora_weights(self):
for _, module in self.named_modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
# Implement `fuse_lora()` and `unfuse_lora()` (sayakpaul).
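As a standalone note on the grouping step in `load_lora_weights` above: the flat LoRA keys are bucketed per target module by splitting off the last three dot-separated components. A minimal sketch of that rule, using hypothetical example keys (only the splitting logic mirrors the diff):

# Sketch of the key-grouping rule used in ControlLoRAMixin.load_lora_weights.
# The example keys below are hypothetical; only the splitting rule mirrors the diff.
from collections import defaultdict

remaining_state_dict = {
    "time_embedding.linear_1.lora.down.weight": "tensor_a",
    "time_embedding.linear_1.lora.up.weight": "tensor_b",
    "conv_in.lora.down.weight": "tensor_c",
    "conv_in.lora.up.weight": "tensor_d",
}

lora_grouped_dict = defaultdict(dict)
for key, value in remaining_state_dict.items():
    # Everything before the last three components addresses the target module;
    # the last three ("lora.down.weight" / "lora.up.weight") address the LoRA tensor.
    module_path, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
    lora_grouped_dict[module_path][sub_key] = value

print(dict(lora_grouped_dict))
# {'time_embedding.linear_1': {'lora.down.weight': 'tensor_a', 'lora.up.weight': 'tensor_b'},
#  'conv_in': {'lora.down.weight': 'tensor_c', 'lora.up.weight': 'tensor_d'}}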

View File

@@ -19,7 +19,8 @@ from torch import nn
from torch.nn import functional as F
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import FromOriginalControlnetMixin
from ..loaders import ControlLoRAMixin, FromOriginalControlnetMixin, UNet2DConditionLoadersMixin
from ..models.lora import LoRACompatibleConv
from ..utils import BaseOutput, logging
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -80,7 +81,7 @@ class ControlNetConditioningEmbedding(nn.Module):
):
super().__init__()
self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
self.conv_in = LoRACompatibleConv(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
self.blocks = nn.ModuleList([])
@@ -96,6 +97,7 @@ class ControlNetConditioningEmbedding(nn.Module):
def forward(self, conditioning):
embedding = self.conv_in(conditioning)
print(f"From conv_in embedding of ControlNet: {embedding[0, :5, :5, -1]}")
embedding = F.silu(embedding)
for block in self.blocks:
@@ -107,7 +109,9 @@ class ControlNetConditioningEmbedding(nn.Module):
return embedding
class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
class ControlNetModel(
ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, FromOriginalControlnetMixin, ControlLoRAMixin
):
"""
A ControlNet model.
@@ -247,7 +251,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
self.conv_in = LoRACompatibleConv(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
@@ -719,6 +723,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
print(f"t_emb: {t_emb[0, :3]}")
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
@@ -726,6 +731,8 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
print(f"emb: {emb[0, :3]}")
aug_emb = None
if self.class_embedding is not None:
@@ -764,6 +771,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
# 2. pre-process
sample = self.conv_in(sample)
print(f"From ControlNet conv_in: {sample[0, :5, :5, -1]}")
controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
sample = sample + controlnet_cond
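For context on why `conv_in` and `controlnet_cond_embedding.conv_in` are switched to `LoRACompatibleConv` above: such a module can carry a `LoRAConv2dLayer` that is added to its output, which is how `ControlLoRAMixin.load_lora_weights` attaches the converted LoRA tensors. A minimal sketch relying only on the classes and the `set_lora_layer` call already referenced in this diff; the channel sizes and input shape are arbitrary assumptions:

# Sketch: attaching a LoRAConv2dLayer to a LoRACompatibleConv, as ControlLoRAMixin
# does for every matched module. Shapes here are arbitrary assumptions.
import torch
from diffusers.models.lora import LoRACompatibleConv, LoRAConv2dLayer

conv = LoRACompatibleConv(3, 320, kernel_size=3, padding=1)
lora = LoRAConv2dLayer(in_features=3, out_features=320, rank=4, kernel_size=(3, 3), padding=(1, 1))
conv.set_lora_layer(lora)

hidden_states = torch.randn(1, 3, 64, 64)
out = conv(hidden_states)  # base convolution plus the (zero-initialized) LoRA branch
print(out.shape)  # torch.Size([1, 320, 64, 64])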

View File

@@ -18,6 +18,7 @@ import numpy as np
import torch
from torch import nn
from ..models.lora import LoRACompatibleLinear
from .activations import get_activation
@@ -166,10 +167,10 @@ class TimestepEmbedding(nn.Module):
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
self.cond_proj = LoRACompatibleLinear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None
@@ -179,7 +180,7 @@ class TimestepEmbedding(nn.Module):
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
if post_act_fn is None:
self.post_act = None
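The same idea applies to the `TimestepEmbedding` change above: its linears become `LoRACompatibleLinear` so the time-embedding LoRA tensors from a Control-LoRA checkpoint have somewhere to attach. A minimal sketch analogous to the conv example earlier, with arbitrary dimensions:

# Sketch: a LoRACompatibleLinear (such as TimestepEmbedding.linear_1 above) carrying
# a LoRALinearLayer. Dimensions are arbitrary assumptions.
import torch
from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer

linear = LoRACompatibleLinear(320, 1280)
linear.set_lora_layer(LoRALinearLayer(320, 1280, rank=4))

t_emb = torch.randn(2, 320)
emb = linear(t_emb)          # base projection plus the (zero-initialized) LoRA branch
linear.set_lora_layer(None)  # detach again, as ControlLoRAMixin.unload_lora_weights does
print(emb.shape)             # torch.Size([2, 1280])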

View File

@@ -40,7 +40,17 @@ def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
def __init__(
self,
in_features,
out_features,
rank=4,
network_alpha=None,
device=None,
dtype=None,
# initial_weight=None,
# initial_bias=None,
):
super().__init__()
self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
@@ -52,6 +62,10 @@ class LoRALinearLayer(nn.Module):
self.out_features = out_features
self.in_features = in_features
# # Control-LoRA specific.
# self.initial_weight = initial_weight
# self.initial_bias = initial_bias
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
@@ -66,11 +80,32 @@ class LoRALinearLayer(nn.Module):
up_hidden_states *= self.network_alpha / self.rank
return up_hidden_states.to(orig_dtype)
# else:
# initial_weight = self.initial_weight
# if initial_weight.device != hidden_states.device:
# initial_weight = initial_weight.to(hidden_states.device)
# return torch.nn.functional.linear(
# hidden_states.to(dtype),
# initial_weight
# + (torch.mm(self.up.weight.data.flatten(start_dim=1), self.down.weight.data.flatten(start_dim=1)))
# .reshape(self.initial_weight.shape)
# .type(orig_dtype),
# self.initial_bias,
# )
class LoRAConv2dLayer(nn.Module):
def __init__(
self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None
self,
in_features,
out_features,
rank=4,
kernel_size=(1, 1),
stride=(1, 1),
padding=0,
network_alpha=None,
# initial_weight=None,
# initial_bias=None,
):
super().__init__()
@@ -84,6 +119,13 @@ class LoRAConv2dLayer(nn.Module):
self.network_alpha = network_alpha
self.rank = rank
# # Control-LoRA specific.
# self.initial_weight = initial_weight
# self.initial_bias = initial_bias
# self.stride = stride
# self.kernel_size = kernel_size
# self.padding = padding
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
@@ -98,6 +140,20 @@ class LoRAConv2dLayer(nn.Module):
up_hidden_states *= self.network_alpha / self.rank
return up_hidden_states.to(orig_dtype)
# else:
# initial_weight = self.initial_weight
# if initial_weight.device != hidden_states.device:
# initial_weight = initial_weight.to(hidden_states.device)
# return torch.nn.functional.conv2d(
# hidden_states,
# initial_weight
# + (torch.mm(self.up.weight.flatten(start_dim=1), self.down.weight.flatten(start_dim=1)))
# .reshape(self.initial_weight.shape)
# .type(orig_dtype),
# self.initial_bias,
# self.stride,
# self.padding,
# )
class LoRACompatibleConv(nn.Conv2d):
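The commented-out branches above sketch computing a fused weight on the fly instead of adding the LoRA output to the base output. As a standalone illustration of that computation for the linear case (plain torch, arbitrary shapes, no diffusers dependency):

# Illustration of the fused-weight computation sketched in the commented-out branch above
# (linear case). Shapes are arbitrary assumptions.
import torch

in_features, out_features, rank = 8, 16, 4
initial_weight = torch.randn(out_features, in_features)  # frozen base weight
down = torch.randn(rank, in_features) / rank             # LoRA down projection
up = torch.zeros(out_features, rank)                     # LoRA up projection (zero-init)

# Effective weight = base weight + (up @ down), i.e. a rank-limited update.
effective_weight = initial_weight + torch.mm(
    up.flatten(start_dim=1), down.flatten(start_dim=1)
).reshape(initial_weight.shape)

hidden_states = torch.randn(2, in_features)
out = torch.nn.functional.linear(hidden_states, effective_weight)  # bias omitted for brevity
print(out.shape)  # torch.Size([2, 16])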

View File

@@ -104,7 +104,10 @@ EXAMPLE_DOC_STRING = """
class StableDiffusionXLControlNetPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
DiffusionPipeline,
TextualInversionLoaderMixin,
LoraLoaderMixin,
FromSingleFileMixin,
):
r"""
Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
@@ -1067,6 +1070,7 @@ class StableDiffusionXLControlNetPipeline(
target_size = target_size or (height, width)
add_text_embeds = pooled_prompt_embeds
print(f"pooled_prompt_embeds: {pooled_prompt_embeds.shape}")
add_time_ids = self._get_add_time_ids(
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
)

View File

@@ -377,11 +377,21 @@ def create_ldm_bert_config(original_config):
def convert_ldm_unet_checkpoint(
checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
checkpoint,
config,
path=None,
extract_ema=False,
controlnet=False,
skip_extract_state_dict=False,
controlnet_lora=False,
):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
if not controlnet and controlnet_lora:
raise ValueError(f"`controlnet_lora` cannot be done with `controlnet` set to {controlnet}.")
if controlnet and controlnet_lora:
skip_extract_state_dict = True
if skip_extract_state_dict:
unet_state_dict = checkpoint
@@ -419,10 +429,22 @@ def convert_ldm_unet_checkpoint(
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
if controlnet_lora:
# Safe to pop, as this key doesn't carry anything we need.
_ = unet_state_dict.pop("lora_controlnet")
if not controlnet_lora:
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
else:
new_checkpoint["time_embedding.linear_1.lora_down.weight"] = unet_state_dict["time_embed.0.down"]
new_checkpoint["time_embedding.linear_1.lora_up.weight"] = unet_state_dict["time_embed.0.up"]
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.lora_down.weight"] = unet_state_dict["time_embed.2.down"]
new_checkpoint["time_embedding.linear_2.lora_up.weight"] = unet_state_dict["time_embed.2.up"]
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
if config["class_embed_type"] is None:
# No parameters to port
@@ -436,13 +458,26 @@ def convert_ldm_unet_checkpoint(
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
if config["addition_embed_type"] == "text_time":
new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
if not controlnet_lora:
new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
else:
new_checkpoint["add_embedding.linear_1.lora_down.weight"] = unet_state_dict["label_emb.0.0.down"]
new_checkpoint["add_embedding.linear_1.lora_up.weight"] = unet_state_dict["label_emb.0.0.up"]
new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
new_checkpoint["add_embedding.linear_2.lora_down.weight"] = unet_state_dict["label_emb.0.2.down"]
new_checkpoint["add_embedding.linear_2.lora_up.weight"] = unet_state_dict["label_emb.0.2.up"]
new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
if not controlnet_lora:
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
else:
new_checkpoint["conv_in.lora_down.weight"] = unet_state_dict["input_blocks.0.0.down"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
new_checkpoint["conv_in.lora_up.weight"] = unet_state_dict["input_blocks.0.0.up"]
if not controlnet:
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
@@ -588,8 +623,9 @@ def convert_ldm_unet_checkpoint(
orig_index += 2
diffusers_index = 0
diffusers_index_limit = 6
while diffusers_index < 6:
while diffusers_index < diffusers_index_limit:
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.weight"
)
@@ -599,12 +635,13 @@ def convert_ldm_unet_checkpoint(
diffusers_index += 1
orig_index += 2
new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.weight"
)
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.bias"
)
if not controlnet_lora:
new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.weight"
)
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.bias"
)
# down blocks
for i in range(num_input_blocks):
@@ -615,6 +652,21 @@ def convert_ldm_unet_checkpoint(
new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
if controlnet_lora:
modified_new_checkpoint = {}
down_pattern = r"\.down$"
up_pattern = r"\.up$"
for key in new_checkpoint:
new_key = key
new_key = re.sub(down_pattern, ".lora.down.weight", new_key)
new_key = re.sub(up_pattern, ".lora.up.weight", new_key)
new_key = new_key.replace("lora_down", "lora.down")
new_key = new_key.replace("lora_up", "lora.up")
modified_new_checkpoint[new_key] = new_checkpoint[key]
new_checkpoint = modified_new_checkpoint
return new_checkpoint
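To make the final remapping concrete: the `controlnet_lora` branch above turns `.down` / `.up` suffixes and `lora_down` / `lora_up` infixes into diffusers-style `.lora.down.weight` / `.lora.up.weight` keys. A small standalone sketch with hypothetical example keys (only the substitution rules mirror the diff):

# Standalone sketch of the controlnet_lora key remapping above.
# The example keys are hypothetical; the substitution rules mirror the diff.
import re

down_pattern = r"\.down$"
up_pattern = r"\.up$"

example_keys = [
    "conv_in.lora_down.weight",
    "conv_in.lora_up.weight",
    "down_blocks.0.resnets.0.conv1.down",
    "down_blocks.0.resnets.0.conv1.up",
]

for key in example_keys:
    new_key = re.sub(down_pattern, ".lora.down.weight", key)
    new_key = re.sub(up_pattern, ".lora.up.weight", new_key)
    new_key = new_key.replace("lora_down", "lora.down").replace("lora_up", "lora.up")
    print(f"{key} -> {new_key}")

# conv_in.lora_down.weight -> conv_in.lora.down.weight
# conv_in.lora_up.weight -> conv_in.lora.up.weight
# down_blocks.0.resnets.0.conv1.down -> down_blocks.0.resnets.0.conv1.lora.down.weight
# down_blocks.0.resnets.0.conv1.up -> down_blocks.0.resnets.0.conv1.lora.up.weight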