Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-12 23:44:30 +08:00)

Compare commits: modular-re...sd35-loras
2 commits: 1000300ed4, 091b185ec8
@@ -665,6 +665,251 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
    return new_state_dict


def _convert_non_diffusers_sd3_lora_to_diffusers(state_dict, prefix=None):
    new_state_dict = {}

    # In the SD3 original implementation of AdaLayerNormContinuous, the linear projection output is split into
    # shift, scale; in diffusers it is split into scale, shift. Swap the linear projection weights here so that the
    # diffusers implementation can be used.
    def swap_scale_shift(weight):
        shift, scale = weight.chunk(2, dim=0)
        new_weight = torch.cat([scale, shift], dim=0)
        return new_weight

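    # Illustrative note (not part of the original diff): swap_scale_shift only reorders rows, e.g.
    # swap_scale_shift(torch.arange(4.0).reshape(4, 1)) returns [[2.], [3.], [0.], [1.]], so a packed
    # [shift, scale] projection weight becomes [scale, shift] without changing its shape.
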
    def calculate_scales(key):
        lora_rank = state_dict[f"{key}.lora_down.weight"].shape[0]
        alpha = state_dict.pop(key + ".alpha")
        scale = alpha / lora_rank

        # calculate scale_down and scale_up
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2

        return scale_down, scale_up

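    # Illustrative note (not part of the original diff): the effective LoRA scale alpha / rank is split between the
    # down and up weights so neither side carries an extreme factor. For example, alpha = 4 and rank = 16 give
    # scale = 0.25, which the loop rebalances into scale_down = 0.5 and scale_up = 0.5 (the product stays 0.25).
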
    def weight_is_sparse(key, rank, num_splits, up_weight):
        dims = [up_weight.shape[0] // num_splits] * num_splits

        is_sparse = False
        requested_rank = rank
        if rank % num_splits == 0:
            requested_rank = rank // num_splits
            is_sparse = True
            i = 0
            for j in range(len(dims)):
                for k in range(len(dims)):
                    if j == k:
                        continue
                    is_sparse = is_sparse and torch.all(
                        up_weight[i : i + dims[j], k * requested_rank : (k + 1) * requested_rank] == 0
                    )
                i += dims[j]
            if is_sparse:
                logger.info(f"weight is sparse: {key}")

        return is_sparse, requested_rank

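    # Illustrative note (not part of the original diff): a fused qkv LoRA counts as "sparse" here when its up weight
    # is block diagonal, i.e. for rank = 3 * requested_rank the q rows only use columns [0, requested_rank), the k
    # rows only use the next requested_rank columns, and so on. Each projection can then keep its own slice of the
    # fused weight instead of sharing the full rank.
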
    # handle only transformer blocks for now.
    layers = set()
    for k in state_dict:
        if "joint_blocks" in k:
            idx = int(k.split("_", 4)[-1].split("_", 1)[0])
            layers.add(idx)
    num_layers = max(layers) + 1

    for i in range(num_layers):
        # norms
        for diffusers_key, orig_key in [
            (f"transformer_blocks.{i}.norm1.linear", f"lora_unet_joint_blocks_{i}_x_block_adaLN_modulation_1")
        ]:
            scale_down, scale_up = calculate_scales(orig_key)
            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
            )
            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up

        if not (i == num_layers - 1):
            for diffusers_key, orig_key in [
                (
                    f"transformer_blocks.{i}.norm1_context.linear",
                    f"lora_unet_joint_blocks_{i}_context_block_adaLN_modulation_1",
                )
            ]:
                scale_down, scale_up = calculate_scales(orig_key)
                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
                )
                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
                )
        else:
            for diffusers_key, orig_key in [
                (
                    f"transformer_blocks.{i}.norm1_context.linear",
                    f"lora_unet_joint_blocks_{i}_context_block_adaLN_modulation_1",
                )
            ]:
                scale_down, scale_up = calculate_scales(orig_key)
                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                    swap_scale_shift(state_dict.pop(f"{orig_key}.lora_down.weight")) * scale_down
                )
                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
                    swap_scale_shift(state_dict.pop(f"{orig_key}.lora_up.weight")) * scale_up
                )

        # output projections
        for diffusers_key, orig_key in [
            (f"transformer_blocks.{i}.attn.to_out.0", f"lora_unet_joint_blocks_{i}_x_block_attn_proj")
        ]:
            scale_down, scale_up = calculate_scales(orig_key)
            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
            )
            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
        if not (i == num_layers - 1):
            for diffusers_key, orig_key in [
                (f"transformer_blocks.{i}.attn.to_add_out", f"lora_unet_joint_blocks_{i}_context_block_attn_proj")
            ]:
                scale_down, scale_up = calculate_scales(orig_key)
                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
                )
                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
                )

        # ffs
        for diffusers_key, orig_key in [
            (f"transformer_blocks.{i}.ff.net.0.proj", f"lora_unet_joint_blocks_{i}_x_block_mlp_fc1")
        ]:
            scale_down, scale_up = calculate_scales(orig_key)
            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
            )
            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up

        for diffusers_key, orig_key in [
            (f"transformer_blocks.{i}.ff.net.2", f"lora_unet_joint_blocks_{i}_x_block_mlp_fc2")
        ]:
            scale_down, scale_up = calculate_scales(orig_key)
            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
            )
            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up

        if not (i == num_layers - 1):
            for diffusers_key, orig_key in [
                (f"transformer_blocks.{i}.ff_context.net.0.proj", f"lora_unet_joint_blocks_{i}_context_block_mlp_fc1")
            ]:
                scale_down, scale_up = calculate_scales(orig_key)
                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
                )
                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
                )

            for diffusers_key, orig_key in [
                (f"transformer_blocks.{i}.ff_context.net.2", f"lora_unet_joint_blocks_{i}_context_block_mlp_fc2")
            ]:
                scale_down, scale_up = calculate_scales(orig_key)
                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
                )
                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
                )

        # core transformer blocks.
        # sample blocks.
        scale_down, scale_up = calculate_scales(f"lora_unet_joint_blocks_{i}_x_block_attn_qkv")
        is_sparse, requested_rank = weight_is_sparse(
            key=f"lora_unet_joint_blocks_{i}_x_block_attn_qkv",
            rank=state_dict[f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_down.weight"].shape[0],
            num_splits=3,
            up_weight=state_dict[f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_up.weight"],
        )
        num_splits = 3
        sample_qkv_lora_down = (
            state_dict.pop(f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_down.weight") * scale_down
        )
        sample_qkv_lora_up = state_dict.pop(f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_up.weight") * scale_up
        dims = [sample_qkv_lora_up.shape[0] // num_splits] * num_splits  # 3 = num_splits
        if not is_sparse:
            for attn_k in ["to_q", "to_k", "to_v"]:
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_A.weight"] = sample_qkv_lora_down
            for attn_k, v in zip(["to_q", "to_k", "to_v"], torch.split(sample_qkv_lora_up, dims, dim=0)):
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = v
        else:
            # down_weight is chunked to each split
            new_state_dict.update(
                {
                    f"transformer_blocks.{i}.attn.{k}.lora_A.weight": v
                    for k, v in zip(["to_q", "to_k", "to_v"], torch.chunk(sample_qkv_lora_down, num_splits, dim=0))
                }
            )  # noqa: C416

            # up_weight is sparse: only non-zero values are copied to each split
            row = 0
            for j, attn_k in enumerate(["to_q", "to_k", "to_v"]):
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = sample_qkv_lora_up[
                    row : row + dims[j], j * requested_rank : (j + 1) * requested_rank
                ].contiguous()
                row += dims[j]

        # context blocks.
        scale_down, scale_up = calculate_scales(f"lora_unet_joint_blocks_{i}_context_block_attn_qkv")
        is_sparse, requested_rank = weight_is_sparse(
            key=f"lora_unet_joint_blocks_{i}_context_block_attn_qkv",
            rank=state_dict[f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_down.weight"].shape[0],
            num_splits=3,
            up_weight=state_dict[f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_up.weight"],
        )
        num_splits = 3
        sample_qkv_lora_down = (
            state_dict.pop(f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_down.weight") * scale_down
        )
        sample_qkv_lora_up = (
            state_dict.pop(f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_up.weight") * scale_up
        )
        dims = [sample_qkv_lora_up.shape[0] // num_splits] * num_splits  # 3 = num_splits
        if not is_sparse:
            for attn_k in ["add_q_proj", "add_k_proj", "add_v_proj"]:
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_A.weight"] = sample_qkv_lora_down
            for attn_k, v in zip(
                ["add_q_proj", "add_k_proj", "add_v_proj"], torch.split(sample_qkv_lora_up, dims, dim=0)
            ):
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = v
        else:
            # down_weight is chunked to each split
            new_state_dict.update(
                {
                    f"transformer_blocks.{i}.attn.{k}.lora_A.weight": v
                    for k, v in zip(
                        ["add_q_proj", "add_k_proj", "add_v_proj"],
                        torch.chunk(sample_qkv_lora_down, num_splits, dim=0),
                    )
                }
            )  # noqa: C416

            # up_weight is sparse: only non-zero values are copied to each split
            row = 0
            for j, attn_k in enumerate(["add_q_proj", "add_k_proj", "add_v_proj"]):
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = sample_qkv_lora_up[
                    row : row + dims[j], j * requested_rank : (j + 1) * requested_rank
                ].contiguous()
                row += dims[j]

    if len(state_dict) > 0:
        raise ValueError(f"`state_dict` should be empty at this point but has: {list(state_dict.keys())}.")

    prefix = prefix or "transformer"
    new_state_dict = {f"{prefix}.{k}": v for k, v in new_state_dict.items()}
    return new_state_dict


def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
    converted_state_dict = {}
    original_state_dict_keys = list(original_state_dict.keys())
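A minimal usage sketch for the converter added above (editorial illustration, not part of the diff): the toy module list, rank, and tensor shapes are assumptions chosen only to cover every key the function pops for a single, final joint block, and `_convert_non_diffusers_sd3_lora_to_diffusers` is assumed to be in scope (e.g. imported from `diffusers.loaders.lora_conversion_utils`).

import torch

rank, dim = 4, 32
toy_state_dict = {}
# One (final) joint block; each module contributes lora_down / lora_up / alpha entries.
for name, out_dim in [
    ("lora_unet_joint_blocks_0_x_block_adaLN_modulation_1", 6 * dim),
    ("lora_unet_joint_blocks_0_context_block_adaLN_modulation_1", 2 * dim),
    ("lora_unet_joint_blocks_0_x_block_attn_proj", dim),
    ("lora_unet_joint_blocks_0_x_block_mlp_fc1", 4 * dim),
    ("lora_unet_joint_blocks_0_x_block_mlp_fc2", dim),
    ("lora_unet_joint_blocks_0_x_block_attn_qkv", 3 * dim),
    ("lora_unet_joint_blocks_0_context_block_attn_qkv", 3 * dim),
]:
    toy_state_dict[f"{name}.lora_down.weight"] = torch.randn(rank, dim)
    toy_state_dict[f"{name}.lora_up.weight"] = torch.randn(out_dim, rank)
    toy_state_dict[f"{name}.alpha"] = torch.tensor(float(rank))

converted = _convert_non_diffusers_sd3_lora_to_diffusers(toy_state_dict, prefix="transformer")
# Keys now follow the diffusers/PEFT naming, e.g.
# "transformer.transformer_blocks.0.norm1.linear.lora_A.weight" and
# "transformer.transformer_blocks.0.attn.to_q.lora_B.weight".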
@@ -38,6 +38,7 @@ from .lora_conversion_utils import (
    _convert_bfl_flux_control_lora_to_diffusers,
    _convert_kohya_flux_lora_to_diffusers,
    _convert_non_diffusers_lora_to_diffusers,
    _convert_non_diffusers_sd3_lora_to_diffusers,
    _convert_xlabs_flux_lora_to_diffusers,
    _maybe_map_sgm_blocks_to_diffusers,
)
@@ -1239,6 +1240,27 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        is_non_diffusers = any("lora_unet" in k for k in state_dict)
        if is_non_diffusers:
            has_only_transformer = all(k.startswith("lora_unet") for k in state_dict)
            if not has_only_transformer:
                state_dict = {k: v for k, v in state_dict.items() if k.startswith("lora_unet")}
                logger.warning(
                    "Some keys in the LoRA checkpoint are not related to transformer blocks and will be filtered out during loading. Please open a new issue with the LoRA checkpoint you are trying to load along with a reproducible snippet - https://github.com/huggingface/diffusers/issues/new."
                )

            all_joint_blocks = all("joint_blocks" in k for k in state_dict)
            if not all_joint_blocks:
                raise ValueError(
                    "Only LoRAs that target the transformer blocks are supported at this point. Please open a new issue with the LoRA checkpoint you are trying to load along with a reproducible snippet - https://github.com/huggingface/diffusers/issues/new."
                )

            has_dual_attention_layers = any("attn2" in k for k in state_dict)
            if has_dual_attention_layers:
                raise ValueError("LoRA state dicts with dual attention layers are not supported.")

            state_dict = _convert_non_diffusers_sd3_lora_to_diffusers(state_dict, prefix=cls.transformer_name)

        return state_dict

    def load_lora_weights(
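For context, a hedged end-to-end sketch of the path this hunk enables (editorial illustration, not part of the diff; the LoRA filename is hypothetical, and any checkpoint whose keys contain `lora_unet` and target `joint_blocks` would be detected as non-diffusers and routed through the new converter):

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("path/to/kohya_style_sd35_lora.safetensors")
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]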
@@ -1283,12 +1305,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

-        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
+        transformer_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")}
        if len(transformer_state_dict) > 0:
            self.load_lora_into_transformer(
                state_dict,
@@ -1299,8 +1320,10 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                _pipeline=self,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )
        else:
            logger.debug("No LoRA keys were found for the transformer.")

-        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+        text_encoder_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.text_encoder_name}.")}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
@@ -1312,8 +1335,10 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                _pipeline=self,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )
        else:
            logger.debug("No LoRA keys were found for the first text encoder.")

-        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if k.startswith("text_encoder_2.")}
        if len(text_encoder_2_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
@@ -1325,6 +1350,8 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                _pipeline=self,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )
        else:
            logger.debug("No LoRA keys were found for the second text encoder.")

    @classmethod
    def load_lora_into_transformer(