mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
2 Commits
custom-cod
...
omi-hidrea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
087827ce77 | ||
|
|
1fef4c7b8c |
@@ -1789,12 +1789,58 @@ def _convert_musubi_wan_lora_to_diffusers(state_dict):
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
|
||||
if not all(k.startswith(non_diffusers_prefix) for k in state_dict):
|
||||
raise ValueError("Invalid LoRA state dict for HiDream.")
|
||||
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
|
||||
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
|
||||
return converted_state_dict
|
||||
def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict):
|
||||
non_diffusers_prefix = "diffusion_model"
|
||||
is_kohya = all(k.startswith(f"{non_diffusers_prefix}.") for k in state_dict)
|
||||
|
||||
def _convert_kohya(state_dict):
|
||||
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
|
||||
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
|
||||
return converted_state_dict
|
||||
|
||||
if is_kohya:
|
||||
return _convert_kohya(state_dict)
|
||||
|
||||
else:
|
||||
assert any(k.startswith(("clip_g.", "clip_l.", "t5.", "llama.", "transformer.")) for k in state_dict)
|
||||
converted_state_dict = {}
|
||||
component = "transformer"
|
||||
compoent_sd = {k: v for k, v in state_dict.items() if k.startswith(f"{component}.")}
|
||||
|
||||
def _convert_omi(key, state_dict, component):
|
||||
down_key = f"{key}.lora_down.weight"
|
||||
down_weight = state_dict.pop(down_key)
|
||||
lora_rank = down_weight.shape[0]
|
||||
|
||||
up_weight_key = f"{key}.lora_up.weight"
|
||||
up_weight = state_dict.pop(up_weight_key)
|
||||
|
||||
alpha_key = f"{key}.alpha"
|
||||
alpha = state_dict.pop(alpha_key)
|
||||
|
||||
# scale weight by alpha and dim
|
||||
scale = alpha / lora_rank
|
||||
# calculate scale_down and scale_up
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
down_weight = down_weight * scale_down
|
||||
up_weight = up_weight * scale_up
|
||||
|
||||
diffusers_down_key = f"{key}.lora_A.weight"
|
||||
converted_state_dict[f"{component}.{diffusers_down_key}"] = down_weight
|
||||
converted_state_dict[f"{component}.{diffusers_down_key.replace('.lora_A.', '.lora_B.')}"] = up_weight
|
||||
|
||||
all_unique_keys = {
|
||||
k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
|
||||
for k in compoent_sd
|
||||
}
|
||||
for k in all_unique_keys:
|
||||
_convert_omi(k, compoent_sd, component=component)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
|
||||
|
||||
@@ -5489,7 +5489,9 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
|
||||
kohya_format = any("diffusion_model" in k for k in state_dict)
|
||||
is_omi_format = any(k.startswith(("clip_g.", "clip_l.", "t5.", "llama.", "transformer.")) for k in state_dict)
|
||||
is_non_diffusers_format = kohya_format or is_omi_format
|
||||
if is_non_diffusers_format:
|
||||
state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user