mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-09 05:54:24 +08:00
Compare commits
1 Commits
add-uv-scr
...
dora-fixes
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
32b1a6fab4 |
@@ -153,12 +153,9 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
|
||||
# Iterate over all LoRA weights.
|
||||
all_lora_keys = list(state_dict.keys())
|
||||
for key in all_lora_keys:
|
||||
if not key.endswith("lora_down.weight"):
|
||||
continue
|
||||
|
||||
# every down weight has a corresponding up weight and potentially an alpha weight
|
||||
lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
|
||||
for key in lora_keys:
|
||||
# Extract LoRA name.
|
||||
lora_name = key.split(".")[0]
|
||||
|
||||
@@ -177,9 +174,12 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
|
||||
# Store DoRA scale if present.
|
||||
if dora_present_in_unet:
|
||||
dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
|
||||
unet_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
new_key = diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
|
||||
# dora_weight = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
dora_weight = state_dict.pop(lora_name + ".dora_scale")
|
||||
if dora_weight.dim() <= 2:
|
||||
dora_weight = dora_weight.squeeze()
|
||||
unet_state_dict[new_key] = dora_weight
|
||||
|
||||
# Handle text encoder LoRAs.
|
||||
elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
|
||||
@@ -194,18 +194,24 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
|
||||
te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
|
||||
|
||||
# Store DoRA scale if present.
|
||||
if dora_present_in_te or dora_present_in_te2:
|
||||
if (dora_present_in_te or dora_present_in_te2):
|
||||
dora_scale_key_to_replace_te = (
|
||||
"_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
|
||||
)
|
||||
if lora_name.startswith(("lora_te_", "lora_te1_")):
|
||||
te_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
new_key = diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
||||
# dora_weight = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
dora_weight = state_dict.pop(lora_name + ".dora_scale")
|
||||
if dora_weight.dim() <= 2:
|
||||
dora_weight = dora_weight.squeeze()
|
||||
te_state_dict[new_key] = dora_weight
|
||||
elif lora_name.startswith("lora_te2_"):
|
||||
te2_state_dict[
|
||||
diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
||||
] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
new_key = diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
|
||||
# dora_weight = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
|
||||
dora_weight = state_dict.pop(lora_name + ".dora_scale")
|
||||
if dora_weight.dim() <= 2:
|
||||
dora_weight = dora_weight.squeeze()
|
||||
te2_state_dict[new_key] = dora_weight
|
||||
|
||||
# Store alpha if present.
|
||||
if lora_name_alpha in state_dict:
|
||||
@@ -214,7 +220,8 @@ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_
|
||||
|
||||
# Check if any keys remain.
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
|
||||
all_keys_remaining = sorted(list(state_dict.keys()))
|
||||
raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(all_keys_remaining)}")
|
||||
|
||||
logger.info("Non-diffusers checkpoint detected.")
|
||||
|
||||
@@ -285,7 +292,7 @@ def _convert_unet_lora_key(key):
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
return diffusers_name
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user