mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-24 05:14:55 +08:00
Compare commits
2 Commits
remove-unn
...
omi-hidrea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
087827ce77 | ||
|
|
1fef4c7b8c |
@@ -1789,12 +1789,58 @@ def _convert_musubi_wan_lora_to_diffusers(state_dict):
|
|||||||
return converted_state_dict
|
return converted_state_dict
|
||||||
|
|
||||||
|
|
||||||
def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
|
def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict):
|
||||||
if not all(k.startswith(non_diffusers_prefix) for k in state_dict):
|
non_diffusers_prefix = "diffusion_model"
|
||||||
raise ValueError("Invalid LoRA state dict for HiDream.")
|
is_kohya = all(k.startswith(f"{non_diffusers_prefix}.") for k in state_dict)
|
||||||
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
|
|
||||||
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
|
def _convert_kohya(state_dict):
|
||||||
return converted_state_dict
|
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
|
||||||
|
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
|
||||||
|
return converted_state_dict
|
||||||
|
|
||||||
|
if is_kohya:
|
||||||
|
return _convert_kohya(state_dict)
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert any(k.startswith(("clip_g.", "clip_l.", "t5.", "llama.", "transformer.")) for k in state_dict)
|
||||||
|
converted_state_dict = {}
|
||||||
|
component = "transformer"
|
||||||
|
compoent_sd = {k: v for k, v in state_dict.items() if k.startswith(f"{component}.")}
|
||||||
|
|
||||||
|
def _convert_omi(key, state_dict, component):
|
||||||
|
down_key = f"{key}.lora_down.weight"
|
||||||
|
down_weight = state_dict.pop(down_key)
|
||||||
|
lora_rank = down_weight.shape[0]
|
||||||
|
|
||||||
|
up_weight_key = f"{key}.lora_up.weight"
|
||||||
|
up_weight = state_dict.pop(up_weight_key)
|
||||||
|
|
||||||
|
alpha_key = f"{key}.alpha"
|
||||||
|
alpha = state_dict.pop(alpha_key)
|
||||||
|
|
||||||
|
# scale weight by alpha and dim
|
||||||
|
scale = alpha / lora_rank
|
||||||
|
# calculate scale_down and scale_up
|
||||||
|
scale_down = scale
|
||||||
|
scale_up = 1.0
|
||||||
|
while scale_down * 2 < scale_up:
|
||||||
|
scale_down *= 2
|
||||||
|
scale_up /= 2
|
||||||
|
down_weight = down_weight * scale_down
|
||||||
|
up_weight = up_weight * scale_up
|
||||||
|
|
||||||
|
diffusers_down_key = f"{key}.lora_A.weight"
|
||||||
|
converted_state_dict[f"{component}.{diffusers_down_key}"] = down_weight
|
||||||
|
converted_state_dict[f"{component}.{diffusers_down_key.replace('.lora_A.', '.lora_B.')}"] = up_weight
|
||||||
|
|
||||||
|
all_unique_keys = {
|
||||||
|
k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
|
||||||
|
for k in compoent_sd
|
||||||
|
}
|
||||||
|
for k in all_unique_keys:
|
||||||
|
_convert_omi(k, compoent_sd, component=component)
|
||||||
|
|
||||||
|
return converted_state_dict
|
||||||
|
|
||||||
|
|
||||||
def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
|
def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
|
||||||
|
|||||||
@@ -5489,7 +5489,9 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
|||||||
logger.warning(warn_msg)
|
logger.warning(warn_msg)
|
||||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||||
|
|
||||||
is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
|
kohya_format = any("diffusion_model" in k for k in state_dict)
|
||||||
|
is_omi_format = any(k.startswith(("clip_g.", "clip_l.", "t5.", "llama.", "transformer.")) for k in state_dict)
|
||||||
|
is_non_diffusers_format = kohya_format or is_omi_format
|
||||||
if is_non_diffusers_format:
|
if is_non_diffusers_format:
|
||||||
state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
|
state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user