mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-07 21:14:44 +08:00

Compare commits: onnx-cpu-d ... sd35-loras (2 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 1000300ed4 |  |
|  | 091b185ec8 |  |
@@ -665,6 +665,251 @@ def _convert_xlabs_flux_lora_to_diffusers(old_state_dict):
    return new_state_dict


def _convert_non_diffusers_sd3_lora_to_diffusers(state_dict, prefix=None):
    new_state_dict = {}

    # In the SD3 original implementation of AdaLayerNormContinuous, the linear projection output is split into shift, scale;
    # in diffusers it is split into scale, shift. Here we swap the linear projection weights so that the diffusers
    # implementation can be used.
    def swap_scale_shift(weight):
        shift, scale = weight.chunk(2, dim=0)
        new_weight = torch.cat([scale, shift], dim=0)
        return new_weight
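    # Worked example (illustrative tensor, not from a checkpoint): for a weight whose rows are laid out as
    # [shift_0, shift_1, scale_0, scale_1], chunk(2, dim=0) yields (shift, scale) and the concatenation returns
    # rows [scale_0, scale_1, shift_0, shift_1], i.e. the ordering the diffusers implementation expects.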

    def calculate_scales(key):
        lora_rank = state_dict[f"{key}.lora_down.weight"].shape[0]
        alpha = state_dict.pop(key + ".alpha")
        scale = alpha / lora_rank

        # calculate scale_down and scale_up
        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2

        return scale_down, scale_up
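    # Worked example with assumed values: alpha = 4 and rank = 16 give scale = 0.25, so scale_down starts at 0.25
    # and scale_up at 1.0; one loop iteration rebalances them to 0.5 and 0.5. The product scale_down * scale_up
    # always equals alpha / rank; the balancing only spreads the factor across the lora_A and lora_B weights.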

    def weight_is_sparse(key, rank, num_splits, up_weight):
        dims = [up_weight.shape[0] // num_splits] * num_splits

        is_sparse = False
        requested_rank = rank
        if rank % num_splits == 0:
            requested_rank = rank // num_splits
            is_sparse = True
            i = 0
            for j in range(len(dims)):
                for k in range(len(dims)):
                    if j == k:
                        continue
                    is_sparse = is_sparse and torch.all(
                        up_weight[i : i + dims[j], k * requested_rank : (k + 1) * requested_rank] == 0
                    )
                i += dims[j]
            if is_sparse:
                logger.info(f"weight is sparse: {key}")

        return is_sparse, requested_rank
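    # Illustration with assumed shapes: for a fused qkv up-weight of shape [3 * dim, rank] with rank divisible
    # by num_splits, "sparse" means block-diagonal - the q rows may only use the first requested_rank columns,
    # the k rows the next requested_rank, and the v rows the last requested_rank; all other blocks must be zero.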

    # handle only transformer blocks for now.
    layers = set()
    for k in state_dict:
        if "joint_blocks" in k:
            idx = int(k.split("_", 4)[-1].split("_", 1)[0])
            layers.add(idx)
    num_layers = max(layers) + 1
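    # Example of the key parsing above: for "lora_unet_joint_blocks_12_x_block_attn_qkv.lora_down.weight",
    # k.split("_", 4)[-1] gives "12_x_block_attn_qkv.lora_down.weight" and .split("_", 1)[0] gives "12",
    # so idx == 12.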

    for i in range(num_layers):
        # norms
        for diffusers_key, orig_key in [
            (f"transformer_blocks.{i}.norm1.linear", f"lora_unet_joint_blocks_{i}_x_block_adaLN_modulation_1")
        ]:
            scale_down, scale_up = calculate_scales(orig_key)
            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
            )
            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up

        if not (i == num_layers - 1):
            for diffusers_key, orig_key in [
                (
                    f"transformer_blocks.{i}.norm1_context.linear",
                    f"lora_unet_joint_blocks_{i}_context_block_adaLN_modulation_1",
                )
            ]:
                scale_down, scale_up = calculate_scales(orig_key)
                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
                )
                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
                )
        else:
            for diffusers_key, orig_key in [
                (
                    f"transformer_blocks.{i}.norm1_context.linear",
                    f"lora_unet_joint_blocks_{i}_context_block_adaLN_modulation_1",
                )
            ]:
                scale_down, scale_up = calculate_scales(orig_key)
                # last layer: the context norm follows AdaLayerNormContinuous, so swap the shift/scale halves
                # (see the comment at the top of this function)
                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                    swap_scale_shift(state_dict.pop(f"{orig_key}.lora_down.weight")) * scale_down
                )
                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
                    swap_scale_shift(state_dict.pop(f"{orig_key}.lora_up.weight")) * scale_up
                )

        # output projections
        for diffusers_key, orig_key in [
            (f"transformer_blocks.{i}.attn.to_out.0", f"lora_unet_joint_blocks_{i}_x_block_attn_proj")
        ]:
            scale_down, scale_up = calculate_scales(orig_key)
            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
            )
            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
        if not (i == num_layers - 1):
            for diffusers_key, orig_key in [
                (f"transformer_blocks.{i}.attn.to_add_out", f"lora_unet_joint_blocks_{i}_context_block_attn_proj")
            ]:
                scale_down, scale_up = calculate_scales(orig_key)
                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
                )
                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
                )

        # ffs
        for diffusers_key, orig_key in [
            (f"transformer_blocks.{i}.ff.net.0.proj", f"lora_unet_joint_blocks_{i}_x_block_mlp_fc1")
        ]:
            scale_down, scale_up = calculate_scales(orig_key)
            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
            )
            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up

        for diffusers_key, orig_key in [
            (f"transformer_blocks.{i}.ff.net.2", f"lora_unet_joint_blocks_{i}_x_block_mlp_fc2")
        ]:
            scale_down, scale_up = calculate_scales(orig_key)
            new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
            )
            new_state_dict[f"{diffusers_key}.lora_B.weight"] = state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up

        if not (i == num_layers - 1):
            for diffusers_key, orig_key in [
                (f"transformer_blocks.{i}.ff_context.net.0.proj", f"lora_unet_joint_blocks_{i}_context_block_mlp_fc1")
            ]:
                scale_down, scale_up = calculate_scales(orig_key)
                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
                )
                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
                )

            for diffusers_key, orig_key in [
                (f"transformer_blocks.{i}.ff_context.net.2", f"lora_unet_joint_blocks_{i}_context_block_mlp_fc2")
            ]:
                scale_down, scale_up = calculate_scales(orig_key)
                new_state_dict[f"{diffusers_key}.lora_A.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_down.weight") * scale_down
                )
                new_state_dict[f"{diffusers_key}.lora_B.weight"] = (
                    state_dict.pop(f"{orig_key}.lora_up.weight") * scale_up
                )

        # core transformer blocks.
        # sample blocks.
        scale_down, scale_up = calculate_scales(f"lora_unet_joint_blocks_{i}_x_block_attn_qkv")
        is_sparse, requested_rank = weight_is_sparse(
            key=f"lora_unet_joint_blocks_{i}_x_block_attn_qkv",
            rank=state_dict[f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_down.weight"].shape[0],
            num_splits=3,
            up_weight=state_dict[f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_up.weight"],
        )
        num_splits = 3
        sample_qkv_lora_down = (
            state_dict.pop(f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_down.weight") * scale_down
        )
        sample_qkv_lora_up = state_dict.pop(f"lora_unet_joint_blocks_{i}_x_block_attn_qkv.lora_up.weight") * scale_up
        dims = [sample_qkv_lora_up.shape[0] // num_splits] * num_splits  # 3 = num_splits
        if not is_sparse:
            for attn_k in ["to_q", "to_k", "to_v"]:
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_A.weight"] = sample_qkv_lora_down
            for attn_k, v in zip(["to_q", "to_k", "to_v"], torch.split(sample_qkv_lora_up, dims, dim=0)):
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = v
        else:
            # down_weight is chunked to each split
            new_state_dict.update(
                {
                    f"transformer_blocks.{i}.attn.{k}.lora_A.weight": v
                    for k, v in zip(["to_q", "to_k", "to_v"], torch.chunk(sample_qkv_lora_down, num_splits, dim=0))
                }
            )  # noqa: C416

            # up_weight is sparse: only non-zero values are copied to each split. Track the row position in a
            # separate `offset` so the layer index `i` used in the `transformer_blocks.{i}` keys stays intact.
            offset = 0
            for j, attn_k in enumerate(["to_q", "to_k", "to_v"]):
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = sample_qkv_lora_up[
                    offset : offset + dims[j], j * requested_rank : (j + 1) * requested_rank
                ].contiguous()
                offset += dims[j]
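            # Illustration with assumed sizes: for dim = 1536 and a fused rank of 48, requested_rank is 16 and
            # dims == [1536, 1536, 1536]; to_q receives rows 0:1536 and columns 0:16, to_k rows 1536:3072 and
            # columns 16:32, and to_v rows 3072:4608 and columns 32:48 of sample_qkv_lora_up.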

        # context blocks.
        scale_down, scale_up = calculate_scales(f"lora_unet_joint_blocks_{i}_context_block_attn_qkv")
        is_sparse, requested_rank = weight_is_sparse(
            key=f"lora_unet_joint_blocks_{i}_context_block_attn_qkv",
            rank=state_dict[f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_down.weight"].shape[0],
            num_splits=3,
            up_weight=state_dict[f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_up.weight"],
        )
        num_splits = 3
        sample_qkv_lora_down = (
            state_dict.pop(f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_down.weight") * scale_down
        )
        sample_qkv_lora_up = (
            state_dict.pop(f"lora_unet_joint_blocks_{i}_context_block_attn_qkv.lora_up.weight") * scale_up
        )
        dims = [sample_qkv_lora_up.shape[0] // num_splits] * num_splits  # 3 = num_splits
        if not is_sparse:
            for attn_k in ["add_q_proj", "add_k_proj", "add_v_proj"]:
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_A.weight"] = sample_qkv_lora_down
            for attn_k, v in zip(
                ["add_q_proj", "add_k_proj", "add_v_proj"], torch.split(sample_qkv_lora_up, dims, dim=0)
            ):
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = v
        else:
            # down_weight is chunked to each split
            new_state_dict.update(
                {
                    f"transformer_blocks.{i}.attn.{k}.lora_A.weight": v
                    for k, v in zip(
                        ["add_q_proj", "add_k_proj", "add_v_proj"],
                        torch.chunk(sample_qkv_lora_down, num_splits, dim=0),
                    )
                }
            )  # noqa: C416

            # up_weight is sparse: only non-zero values are copied to each split.
            # As above, use a dedicated row offset rather than the enclosing layer index `i`.
            offset = 0
            for j, attn_k in enumerate(["add_q_proj", "add_k_proj", "add_v_proj"]):
                new_state_dict[f"transformer_blocks.{i}.attn.{attn_k}.lora_B.weight"] = sample_qkv_lora_up[
                    offset : offset + dims[j], j * requested_rank : (j + 1) * requested_rank
                ].contiguous()
                offset += dims[j]

    if len(state_dict) > 0:
        raise ValueError(f"`state_dict` should be empty at this point but has: {list(state_dict.keys())}.")

    prefix = prefix or "transformer"
    new_state_dict = {f"{prefix}.{k}": v for k, v in new_state_dict.items()}
    return new_state_dict


def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
    converted_state_dict = {}
    original_state_dict_keys = list(original_state_dict.keys())

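For orientation, here is a minimal sketch of how the converter added above could be exercised on a synthetic, single-layer state dict. The key names follow the `lora_unet_joint_blocks_*` pattern the function handles; the rank, dimensions and alpha values are illustrative assumptions, and the import path assumes the function lands in `diffusers.loaders.lora_conversion_utils` as in this diff.

import torch

from diffusers.loaders.lora_conversion_utils import _convert_non_diffusers_sd3_lora_to_diffusers

rank, dim = 4, 64  # illustrative sizes; real SD3.5 checkpoints are much larger


def add_module(sd, key, out_features, in_features=dim, alpha=2.0):
    # each LoRA module contributes a lora_down, a lora_up and an alpha entry
    sd[f"{key}.lora_down.weight"] = torch.randn(rank, in_features)
    sd[f"{key}.lora_up.weight"] = torch.randn(out_features, rank)
    sd[f"{key}.alpha"] = torch.tensor(alpha)


sd = {}
base = "lora_unet_joint_blocks_0"
add_module(sd, f"{base}_x_block_adaLN_modulation_1", 6 * dim)
add_module(sd, f"{base}_context_block_adaLN_modulation_1", 2 * dim)  # last layer: swap_scale_shift path
add_module(sd, f"{base}_x_block_attn_qkv", 3 * dim)
add_module(sd, f"{base}_context_block_attn_qkv", 3 * dim)
add_module(sd, f"{base}_x_block_attn_proj", dim)
add_module(sd, f"{base}_x_block_mlp_fc1", 4 * dim)
add_module(sd, f"{base}_x_block_mlp_fc2", dim, in_features=4 * dim)

converted = _convert_non_diffusers_sd3_lora_to_diffusers(sd, prefix="transformer")
print(sorted(converted)[:2])
# e.g. ['transformer.transformer_blocks.0.attn.to_k.lora_A.weight',
#       'transformer.transformer_blocks.0.attn.to_k.lora_B.weight']
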
@@ -38,6 +38,7 @@ from .lora_conversion_utils import (
    _convert_bfl_flux_control_lora_to_diffusers,
    _convert_kohya_flux_lora_to_diffusers,
    _convert_non_diffusers_lora_to_diffusers,
    _convert_non_diffusers_sd3_lora_to_diffusers,
    _convert_xlabs_flux_lora_to_diffusers,
    _maybe_map_sgm_blocks_to_diffusers,
)
@@ -1239,6 +1240,27 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        is_non_diffusers = any("lora_unet" in k for k in state_dict)
        if is_non_diffusers:
            has_only_transformer = all(k.startswith("lora_unet") for k in state_dict)
            if not has_only_transformer:
                state_dict = {k: v for k, v in state_dict.items() if k.startswith("lora_unet")}
                logger.warning(
                    "Some keys in the LoRA checkpoint are not related to transformer blocks and we will filter them out during loading. Please open a new issue with the LoRA checkpoint you are trying to load, along with a reproducible snippet - https://github.com/huggingface/diffusers/issues/new."
                )

            all_joint_blocks = all("joint_blocks" in k for k in state_dict)
            if not all_joint_blocks:
                raise ValueError(
                    "LoRAs containing only transformer blocks are supported at this point. Please open a new issue with the LoRA checkpoint you are trying to load, along with a reproducible snippet - https://github.com/huggingface/diffusers/issues/new."
                )

            has_dual_attention_layers = any("attn2" in k for k in state_dict)
            if has_dual_attention_layers:
                raise ValueError("LoRA state dicts with dual attention layers are not supported.")

            state_dict = _convert_non_diffusers_sd3_lora_to_diffusers(state_dict, prefix=cls.transformer_name)

        return state_dict

    def load_lora_weights(
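As a quick illustration of the detection above (hypothetical keys, not from a real checkpoint):

kohya_keys = ["lora_unet_joint_blocks_0_x_block_attn_qkv.lora_down.weight"]
diffusers_keys = ["transformer.transformer_blocks.0.attn.to_q.lora_A.weight"]

assert any("lora_unet" in k for k in kohya_keys)          # routed through the new conversion branch
assert not any("lora_unet" in k for k in diffusers_keys)  # already in diffusers format, loaded as before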
@@ -1283,12 +1305,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

-        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
+        transformer_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")}
        if len(transformer_state_dict) > 0:
            self.load_lora_into_transformer(
                state_dict,
@@ -1299,8 +1320,10 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                _pipeline=self,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )
        else:
            logger.debug("No LoRA keys were found for the transformer.")

-        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+        text_encoder_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.text_encoder_name}.")}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
@@ -1312,8 +1335,10 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                _pipeline=self,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )
        else:
            logger.debug("No LoRA keys were found for the first text encoder.")

-        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if k.startswith("text_encoder_2.")}
        if len(text_encoder_2_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
@@ -1325,6 +1350,8 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                _pipeline=self,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )
        else:
            logger.debug("No LoRA keys were found for the second text encoder.")

    @classmethod
    def load_lora_into_transformer(
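With these changes in place, loading a Kohya-style SD3.5 LoRA should go through the regular pipeline API. A minimal sketch follows; the LoRA repo id and weight filename are placeholders, and the dtype/device handling is illustrative.

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.float16
).to("cuda")

# `load_lora_weights` routes non-diffusers checkpoints (keys starting with `lora_unet_...`)
# through `_convert_non_diffusers_sd3_lora_to_diffusers` via `lora_state_dict`.
pipe.load_lora_weights("some-user/sd3.5-lora", weight_name="lora.safetensors")  # placeholder repo/file

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
image.save("sd35_lora.png")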