Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-12 07:24:32 +08:00)

Comparing branches: remove-lor...refactor-l (1 commit)

Commit aff05dc742
@@ -45,7 +45,7 @@ from ..utils import (
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
-from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
+from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
 
 
 if is_transformers_available():
@@ -302,7 +302,7 @@ class LoraLoaderMixin:
         if unet_config is not None:
             # use unet config to remap block numbers
             state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
-        state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
+        state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
 
         return state_dict, network_alphas
 
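Note: the call site above appears to sit on the `lora_state_dict` path of `LoraLoaderMixin`, so only the private converter helper is renamed; the public loading API is untouched. A rough usage sketch of that path (the repository and file names below are placeholders, not part of this diff):

# Loading a Kohya-style (non-diffusers) LoRA still goes through load_lora_weights(),
# which internally builds the state dict via the renamed converter.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
# hypothetical repo and weight name, for illustration only
pipe.load_lora_weights("some-user/some-kohya-lora", weight_name="lora.safetensors")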
@@ -13,6 +13,9 @@
 # limitations under the License.
 
 import re
+from typing import Any, Dict, Tuple
 
+import torch
+
 from ..utils import is_peft_version, logging
 
@@ -123,164 +126,163 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
     return new_state_dict
 
 
-def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
-    unet_state_dict = {}
-    te_state_dict = {}
-    te2_state_dict = {}
-    network_alphas = {}
-    is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
-    is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
-    is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
-
-    if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
-        if is_peft_version("<", "0.9.0"):
-            raise ValueError(
-                "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
-            )
-
-    # every down weight has a corresponding up weight and potentially an alpha weight
-    lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
-    for key in lora_keys:
-        lora_name = key.split(".")[0]
-        lora_name_up = lora_name + ".lora_up.weight"
-        lora_name_alpha = lora_name + ".alpha"
-
-        if lora_name.startswith("lora_unet_"):
-            diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
-
-            if "input.blocks" in diffusers_name:
-                diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
-            else:
-                diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
-
-            if "middle.block" in diffusers_name:
-                diffusers_name = diffusers_name.replace("middle.block", "mid_block")
-            else:
-                diffusers_name = diffusers_name.replace("mid.block", "mid_block")
-            if "output.blocks" in diffusers_name:
-                diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
-            else:
-                diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
-
-            diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
-            diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
-            diffusers_name = diffusers_name.replace("proj.in", "proj_in")
-            diffusers_name = diffusers_name.replace("proj.out", "proj_out")
-            diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
-
-            # SDXL specificity.
-            if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
-                pattern = r"\.\d+(?=\D*$)"
-                diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
-            if ".in." in diffusers_name:
-                diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
-            if ".out." in diffusers_name:
-                diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
-            if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
-                diffusers_name = diffusers_name.replace("op", "conv")
-            if "skip" in diffusers_name:
-                diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
-
-            # LyCORIS specificity.
-            if "time.emb.proj" in diffusers_name:
-                diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
-            if "conv.shortcut" in diffusers_name:
-                diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
-
-            # General coverage.
-            if "transformer_blocks" in diffusers_name:
-                if "attn1" in diffusers_name or "attn2" in diffusers_name:
-                    diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
-                    diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
-                    unet_state_dict[diffusers_name] = state_dict.pop(key)
-                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-                elif "ff" in diffusers_name:
-                    unet_state_dict[diffusers_name] = state_dict.pop(key)
-                    unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
-                unet_state_dict[diffusers_name] = state_dict.pop(key)
-                unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            else:
-                unet_state_dict[diffusers_name] = state_dict.pop(key)
-                unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-
-            if is_unet_dora_lora:
-                dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
-                unet_state_dict[
-                    diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
-                ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
-
-        elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
-            if lora_name.startswith(("lora_te_", "lora_te1_")):
-                key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
-            else:
-                key_to_replace = "lora_te2_"
-
-            diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
-            diffusers_name = diffusers_name.replace("text.model", "text_model")
-            diffusers_name = diffusers_name.replace("self.attn", "self_attn")
-            diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
-            diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
-            diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
-            diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
-
-            if "self_attn" in diffusers_name:
-                if lora_name.startswith(("lora_te_", "lora_te1_")):
-                    te_state_dict[diffusers_name] = state_dict.pop(key)
-                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-                else:
-                    te2_state_dict[diffusers_name] = state_dict.pop(key)
-                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-            elif "mlp" in diffusers_name:
-                # Be aware that this is the new diffusers convention and the rest of the code might
-                # not utilize it yet.
-                diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
-                if lora_name.startswith(("lora_te_", "lora_te1_")):
-                    te_state_dict[diffusers_name] = state_dict.pop(key)
-                    te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-                else:
-                    te2_state_dict[diffusers_name] = state_dict.pop(key)
-                    te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
-
-            if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
-                dora_scale_key_to_replace_te = (
-                    "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
-                )
-                if lora_name.startswith(("lora_te_", "lora_te1_")):
-                    te_state_dict[
-                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
-                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
-                elif lora_name.startswith("lora_te2_"):
-                    te2_state_dict[
-                        diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
-                    ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
-
-        # Rename the alphas so that they can be mapped appropriately.
-        if lora_name_alpha in state_dict:
-            alpha = state_dict.pop(lora_name_alpha).item()
-            if lora_name_alpha.startswith("lora_unet_"):
-                prefix = "unet."
-            elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
-                prefix = "text_encoder."
-            else:
-                prefix = "text_encoder_2."
-            new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
-            network_alphas.update({new_name: alpha})
-
-    if len(state_dict) > 0:
-        raise ValueError(f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}")
-
-    logger.info("Kohya-style checkpoint detected.")
-    unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
-    te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
-    te2_state_dict = (
-        {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
-        if len(te2_state_dict) > 0
-        else None
-    )
-    if te2_state_dict is not None:
-        te_state_dict.update(te2_state_dict)
-
-    new_state_dict = {**unet_state_dict, **te_state_dict}
+def _convert_non_diffusers_lora_to_diffusers(
+    state_dict: Dict[str, torch.Tensor], unet_name: str = "unet", text_encoder_name: str = "text_encoder"
+) -> Tuple[Dict[str, Any], Dict[str, float]]:
+    def detect_dora_lora(state_dict: Dict[str, torch.Tensor]) -> Tuple[bool, bool, bool]:
+        is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
+        is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
+        is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
+        return is_unet_dora_lora, is_te_dora_lora, is_te2_dora_lora
+
+    def check_peft_version(is_unet_dora_lora: bool, is_te_dora_lora: bool, is_te2_dora_lora: bool):
+        if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
+            if is_peft_version("<", "0.9.0"):
+                raise ValueError(
+                    "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                )
+
+    def rename_keys(
+        state_dict: Dict[str, torch.Tensor],
+        key: str,
+        unet_state_dict: Dict[str, torch.Tensor],
+        te_state_dict: Dict[str, torch.Tensor],
+        te2_state_dict: Dict[str, torch.Tensor],
+        is_unet_dora_lora: bool,
+        is_te_dora_lora: bool,
+        is_te2_dora_lora: bool,
+    ):
+        lora_name = key.split(".")[0]
+        lora_name_up = lora_name + ".lora_up.weight"
+        diffusers_name = key.replace(lora_name + ".", "").replace("_", ".")
+        lora_type = lora_name.split("_")[1]
+
+        if lora_type == "unet":
+            diffusers_name = _adjust_unet_names(diffusers_name)
+            unet_state_dict = _populate_state_dict(
+                unet_state_dict, state_dict, key, lora_name_up, diffusers_name, is_unet_dora_lora
+            )
+        else:
+            diffusers_name = _adjust_text_encoder_names(diffusers_name)
+            if lora_type in ["te", "te1"]:
+                te_state_dict = _populate_state_dict(
+                    te_state_dict, state_dict, key, lora_name_up, diffusers_name, is_te_dora_lora
+                )
+            else:
+                te2_state_dict = _populate_state_dict(
+                    te2_state_dict, state_dict, key, lora_name_up, diffusers_name, is_te2_dora_lora
+                )
+
+        return unet_state_dict, te_state_dict, te2_state_dict
+
+    def _adjust_unet_names(name: str) -> str:
+        replacements = [
+            ("input.blocks", "down_blocks"),
+            ("down.blocks", "down_blocks"),
+            ("middle.block", "mid_block"),
+            ("mid.block", "mid_block"),
+            ("output.blocks", "up_blocks"),
+            ("up.blocks", "up_blocks"),
+            ("transformer.blocks", "transformer_blocks"),
+            ("to.q.lora", "to_q_lora"),
+            ("to.k.lora", "to_k_lora"),
+            ("to.v.lora", "to_v_lora"),
+            ("to.out.0.lora", "to_out_lora"),
+            ("proj.in", "proj_in"),
+            ("proj.out", "proj_out"),
+            ("emb.layers", "time_emb_proj"),
+            ("time.emb.proj", "time_emb_proj"),
+            ("conv.shortcut", "conv_shortcut"),
+            ("skip.connection", "conv_shortcut"),
+        ]
+        for old, new in replacements:
+            name = name.replace(old, new)
+        if "emb" in name and "time.emb.proj" not in name:
+            pattern = r"\.\d+(?=\D*$)"
+            name = re.sub(pattern, "", name, count=1)
+        if ".in." in name:
+            name = name.replace("in.layers.2", "conv1")
+        if ".out." in name:
+            name = name.replace("out.layers.3", "conv2")
+        if "downsamplers" in name or "upsamplers" in name:
+            name = name.replace("op", "conv")
+        return name
+
+    def _adjust_text_encoder_names(name: str) -> str:
+        replacements = [
+            ("text.model", "text_model"),
+            ("self.attn", "self_attn"),
+            ("q.proj.lora", "to_q_lora"),
+            ("k.proj.lora", "to_k_lora"),
+            ("v.proj.lora", "to_v_lora"),
+            ("out.proj.lora", "to_out_lora"),
+            ("text.projection", "text_projection"),
+        ]
+        for old, new in replacements:
+            name = name.replace(old, new)
+        return name
+
+    def _populate_state_dict(state_dict, main_dict, down_key, up_key, name, is_dora_lora):
+        state_dict[name] = main_dict.pop(down_key)
+        state_dict[name.replace(".down.", ".up.")] = main_dict.pop(up_key)
+        if is_dora_lora:
+            dora_key = down_key.replace("lora_down.weight", "dora_scale")
+            scale_key = "_lora.down." if "_lora.down." in name else ".lora.down."
+            state_dict[name.replace(scale_key, ".lora_magnitude_vector.")] = main_dict.pop(dora_key)
+        return state_dict
+
+    def update_network_alphas(
+        state_dict: Dict[str, torch.Tensor],
+        network_alphas: Dict[str, float],
+        diffusers_name: str,
+        lora_name_alpha: str,
+    ):
+        if lora_name_alpha in state_dict:
+            alpha = state_dict.pop(lora_name_alpha).item()
+            prefix = (
+                "unet."
+                if "unet" in lora_name_alpha
+                else "text_encoder."
+                if "te1" in lora_name_alpha
+                else "text_encoder_2."
+            )
+            new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
+            network_alphas.update({new_name: alpha})
+
+    unet_state_dict = {}
+    te_state_dict = {}
+    te2_state_dict = {}
+    network_alphas = {}
+
+    is_unet_dora_lora, is_te_dora_lora, is_te2_dora_lora = detect_dora_lora(state_dict)
+    check_peft_version(is_unet_dora_lora, is_te_dora_lora, is_te2_dora_lora)
+
+    lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
+    for key in lora_keys:
+        unet_state_dict, te_state_dict, te2_state_dict = rename_keys(
+            state_dict,
+            key,
+            unet_state_dict,
+            te_state_dict,
+            te2_state_dict,
+            is_unet_dora_lora,
+            is_te_dora_lora,
+            is_te2_dora_lora,
+        )
+        lora_name = key.split(".")[0]
+        lora_name_alpha = lora_name + ".alpha"
+        diffusers_name = key.replace(lora_name + ".", "").replace("_", ".")
+        update_network_alphas(state_dict, network_alphas, diffusers_name, lora_name_alpha)
+
+    if state_dict:
+        raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
+
+    logger.info("Non-diffusers LoRA checkpoint detected.")
+
+    unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
+    te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
+    if te2_state_dict:
+        te2_state_dict = {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
+        te_state_dict.update(te2_state_dict)
+
+    new_state_dict = {**unet_state_dict, **te_state_dict}
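The bulk of this commit moves the flat Kohya-to-diffusers renaming logic into small nested helpers (detect_dora_lora, check_peft_version, rename_keys, _adjust_unet_names, _adjust_text_encoder_names, _populate_state_dict, update_network_alphas) while keeping the naming convention itself. A minimal, self-contained sketch of that convention, applying only a subset of the replacements shown in the diff to an illustrative key (this is not the library function):

# Kohya-style keys encode the module path with underscores and a "lora_unet_" prefix;
# the converter strips the prefix, switches to dotted paths, and restores the
# underscores that diffusers module names actually use.
key = "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"
name = key.replace("lora_unet_", "").replace("_", ".")
for old, new in [
    ("down.blocks", "down_blocks"),
    ("transformer.blocks", "transformer_blocks"),
    ("to.q.lora", "to_q_lora"),
]:
    name = name.replace(old, new)
print(name)
# down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q_lora.down.weight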