Compare commits

...

1 Commits

Author SHA1 Message Date
sayakpaul
aff05dc742 refactor non-diffusers lora conversion utility. 2024-05-24 16:25:22 +05:30
2 changed files with 148 additions and 146 deletions

View File

@@ -45,7 +45,7 @@ from ..utils import (
set_adapter_layers, set_adapter_layers,
set_weights_and_activate_adapters, set_weights_and_activate_adapters,
) )
from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
if is_transformers_available(): if is_transformers_available():
@@ -302,7 +302,7 @@ class LoraLoaderMixin:
if unet_config is not None: if unet_config is not None:
# use unet config to remap block numbers # use unet config to remap block numbers
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config) state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict) state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
return state_dict, network_alphas return state_dict, network_alphas

View File

@@ -13,6 +13,9 @@
# limitations under the License. # limitations under the License.
import re import re
from typing import Any, Dict, Tuple
import torch
from ..utils import is_peft_version, logging from ..utils import is_peft_version, logging
@@ -123,164 +126,163 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
return new_state_dict return new_state_dict
def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"): def _convert_non_diffusers_lora_to_diffusers(
unet_state_dict = {} state_dict: Dict[str, torch.Tensor], unet_name: str = "unet", text_encoder_name: str = "text_encoder"
te_state_dict = {} ) -> Tuple[Dict[str, Any], Dict[str, float]]:
te2_state_dict = {} def detect_dora_lora(state_dict: Dict[str, torch.Tensor]) -> Tuple[bool, bool, bool]:
network_alphas = {}
is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict) is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict) is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict) is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
return is_unet_dora_lora, is_te_dora_lora, is_te2_dora_lora
def check_peft_version(is_unet_dora_lora: bool, is_te_dora_lora: bool, is_te2_dora_lora: bool):
if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora: if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
if is_peft_version("<", "0.9.0"): if is_peft_version("<", "0.9.0"):
raise ValueError( raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`." "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
) )
# every down weight has a corresponding up weight and potentially an alpha weight def rename_keys(
lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")] state_dict: Dict[str, torch.Tensor],
for key in lora_keys: key: str,
unet_state_dict: Dict[str, torch.Tensor],
te_state_dict: Dict[str, torch.Tensor],
te2_state_dict: Dict[str, torch.Tensor],
is_unet_dora_lora: bool,
is_te_dora_lora: bool,
is_te2_dora_lora: bool,
):
lora_name = key.split(".")[0] lora_name = key.split(".")[0]
lora_name_up = lora_name + ".lora_up.weight" lora_name_up = lora_name + ".lora_up.weight"
lora_name_alpha = lora_name + ".alpha" diffusers_name = key.replace(lora_name + ".", "").replace("_", ".")
lora_type = lora_name.split("_")[1]
if lora_name.startswith("lora_unet_"): if lora_type == "unet":
diffusers_name = key.replace("lora_unet_", "").replace("_", ".") diffusers_name = _adjust_unet_names(diffusers_name)
unet_state_dict = _populate_state_dict(
if "input.blocks" in diffusers_name: unet_state_dict, state_dict, key, lora_name_up, diffusers_name, is_unet_dora_lora
diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") )
else: else:
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") diffusers_name = _adjust_text_encoder_names(diffusers_name)
if lora_type in ["te", "te1"]:
if "middle.block" in diffusers_name: te_state_dict = _populate_state_dict(
diffusers_name = diffusers_name.replace("middle.block", "mid_block") te_state_dict, state_dict, key, lora_name_up, diffusers_name, is_te_dora_lora
else: )
diffusers_name = diffusers_name.replace("mid.block", "mid_block") else:
if "output.blocks" in diffusers_name: te2_state_dict = _populate_state_dict(
diffusers_name = diffusers_name.replace("output.blocks", "up_blocks") te2_state_dict, state_dict, key, lora_name_up, diffusers_name, is_te2_dora_lora
else:
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
diffusers_name = diffusers_name.replace("proj.in", "proj_in")
diffusers_name = diffusers_name.replace("proj.out", "proj_out")
diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
# SDXL specificity.
if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
pattern = r"\.\d+(?=\D*$)"
diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
if ".in." in diffusers_name:
diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
if ".out." in diffusers_name:
diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
diffusers_name = diffusers_name.replace("op", "conv")
if "skip" in diffusers_name:
diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
# LyCORIS specificity.
if "time.emb.proj" in diffusers_name:
diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
if "conv.shortcut" in diffusers_name:
diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
# General coverage.
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "ff" in diffusers_name:
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
else:
unet_state_dict[diffusers_name] = state_dict.pop(key)
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
if is_unet_dora_lora:
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
unet_state_dict[
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
if lora_name.startswith(("lora_te_", "lora_te1_")):
key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
else:
key_to_replace = "lora_te2_"
diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
if lora_name.startswith(("lora_te_", "lora_te1_")):
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
else:
te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
elif "mlp" in diffusers_name:
# Be aware that this is the new diffusers convention and the rest of the code might
# not utilize it yet.
diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
if lora_name.startswith(("lora_te_", "lora_te1_")):
te_state_dict[diffusers_name] = state_dict.pop(key)
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
else:
te2_state_dict[diffusers_name] = state_dict.pop(key)
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
dora_scale_key_to_replace_te = (
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
) )
if lora_name.startswith(("lora_te_", "lora_te1_")):
te_state_dict[
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
elif lora_name.startswith("lora_te2_"):
te2_state_dict[
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
# Rename the alphas so that they can be mapped appropriately. return unet_state_dict, te_state_dict, te2_state_dict
def _adjust_unet_names(name: str) -> str:
replacements = [
("input.blocks", "down_blocks"),
("down.blocks", "down_blocks"),
("middle.block", "mid_block"),
("mid.block", "mid_block"),
("output.blocks", "up_blocks"),
("up.blocks", "up_blocks"),
("transformer.blocks", "transformer_blocks"),
("to.q.lora", "to_q_lora"),
("to.k.lora", "to_k_lora"),
("to.v.lora", "to_v_lora"),
("to.out.0.lora", "to_out_lora"),
("proj.in", "proj_in"),
("proj.out", "proj_out"),
("emb.layers", "time_emb_proj"),
("time.emb.proj", "time_emb_proj"),
("conv.shortcut", "conv_shortcut"),
("skip.connection", "conv_shortcut"),
]
for old, new in replacements:
name = name.replace(old, new)
if "emb" in name and "time.emb.proj" not in name:
pattern = r"\.\d+(?=\D*$)"
name = re.sub(pattern, "", name, count=1)
if ".in." in name:
name = name.replace("in.layers.2", "conv1")
if ".out." in name:
name = name.replace("out.layers.3", "conv2")
if "downsamplers" in name or "upsamplers" in name:
name = name.replace("op", "conv")
return name
def _adjust_text_encoder_names(name: str) -> str:
replacements = [
("text.model", "text_model"),
("self.attn", "self_attn"),
("q.proj.lora", "to_q_lora"),
("k.proj.lora", "to_k_lora"),
("v.proj.lora", "to_v_lora"),
("out.proj.lora", "to_out_lora"),
("text.projection", "text_projection"),
]
for old, new in replacements:
name = name.replace(old, new)
return name
def _populate_state_dict(state_dict, main_dict, down_key, up_key, name, is_dora_lora):
state_dict[name] = main_dict.pop(down_key)
state_dict[name.replace(".down.", ".up.")] = main_dict.pop(up_key)
if is_dora_lora:
dora_key = down_key.replace("lora_down.weight", "dora_scale")
scale_key = "_lora.down." if "_lora.down." in name else ".lora.down."
state_dict[name.replace(scale_key, ".lora_magnitude_vector.")] = main_dict.pop(dora_key)
return state_dict
def update_network_alphas(
state_dict: Dict[str, torch.Tensor],
network_alphas: Dict[str, float],
diffusers_name: str,
lora_name_alpha: str,
):
if lora_name_alpha in state_dict: if lora_name_alpha in state_dict:
alpha = state_dict.pop(lora_name_alpha).item() alpha = state_dict.pop(lora_name_alpha).item()
if lora_name_alpha.startswith("lora_unet_"): prefix = (
prefix = "unet." "unet."
elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")): if "unet" in lora_name_alpha
prefix = "text_encoder." else "text_encoder."
else: if "te1" in lora_name_alpha
prefix = "text_encoder_2." else "text_encoder_2."
)
new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
network_alphas.update({new_name: alpha}) network_alphas.update({new_name: alpha})
if len(state_dict) > 0: unet_state_dict = {}
raise ValueError(f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}") te_state_dict = {}
te2_state_dict = {}
network_alphas = {}
logger.info("Kohya-style checkpoint detected.") is_unet_dora_lora, is_te_dora_lora, is_te2_dora_lora = detect_dora_lora(state_dict)
check_peft_version(is_unet_dora_lora, is_te_dora_lora, is_te2_dora_lora)
lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
for key in lora_keys:
unet_state_dict, te_state_dict, te2_state_dict = rename_keys(
state_dict,
key,
unet_state_dict,
te_state_dict,
te2_state_dict,
is_unet_dora_lora,
is_te_dora_lora,
is_te2_dora_lora,
)
lora_name = key.split(".")[0]
lora_name_alpha = lora_name + ".alpha"
diffusers_name = key.replace(lora_name + ".", "").replace("_", ".")
update_network_alphas(state_dict, network_alphas, diffusers_name, lora_name_alpha)
if state_dict:
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
logger.info("Non-diffusers LoRA checkpoint detected.")
unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()} unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()} te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
te2_state_dict = (
{f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()} if te2_state_dict:
if len(te2_state_dict) > 0 te2_state_dict = {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()}
else None
)
if te2_state_dict is not None:
te_state_dict.update(te2_state_dict) te_state_dict.update(te2_state_dict)
new_state_dict = {**unet_state_dict, **te_state_dict} new_state_dict = {**unet_state_dict, **te_state_dict}