Compare commits

...

2 Commits

Author SHA1 Message Date
Sayak Paul
087827ce77 Merge branch 'main' into omi-hidream-lora 2025-06-06 08:49:55 +05:30
sayakpaul
1fef4c7b8c omi lora. 2025-06-05 08:13:14 +05:30
2 changed files with 55 additions and 7 deletions

View File

@@ -1789,12 +1789,58 @@ def _convert_musubi_wan_lora_to_diffusers(state_dict):
return converted_state_dict
def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
if not all(k.startswith(non_diffusers_prefix) for k in state_dict):
raise ValueError("Invalid LoRA state dict for HiDream.")
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict
def _convert_non_diffusers_hidream_lora_to_diffusers(state_dict):
non_diffusers_prefix = "diffusion_model"
is_kohya = all(k.startswith(f"{non_diffusers_prefix}.") for k in state_dict)
def _convert_kohya(state_dict):
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
converted_state_dict = {f"transformer.{k}": v for k, v in converted_state_dict.items()}
return converted_state_dict
if is_kohya:
return _convert_kohya(state_dict)
else:
assert any(k.startswith(("clip_g.", "clip_l.", "t5.", "llama.", "transformer.")) for k in state_dict)
converted_state_dict = {}
component = "transformer"
compoent_sd = {k: v for k, v in state_dict.items() if k.startswith(f"{component}.")}
def _convert_omi(key, state_dict, component):
down_key = f"{key}.lora_down.weight"
down_weight = state_dict.pop(down_key)
lora_rank = down_weight.shape[0]
up_weight_key = f"{key}.lora_up.weight"
up_weight = state_dict.pop(up_weight_key)
alpha_key = f"{key}.alpha"
alpha = state_dict.pop(alpha_key)
# scale weight by alpha and dim
scale = alpha / lora_rank
# calculate scale_down and scale_up
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
down_weight = down_weight * scale_down
up_weight = up_weight * scale_up
diffusers_down_key = f"{key}.lora_A.weight"
converted_state_dict[f"{component}.{diffusers_down_key}"] = down_weight
converted_state_dict[f"{component}.{diffusers_down_key.replace('.lora_A.', '.lora_B.')}"] = up_weight
all_unique_keys = {
k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "")
for k in compoent_sd
}
for k in all_unique_keys:
_convert_omi(k, compoent_sd, component=component)
return converted_state_dict
def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):

View File

@@ -5489,7 +5489,9 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
kohya_format = any("diffusion_model" in k for k in state_dict)
is_omi_format = any(k.startswith(("clip_g.", "clip_l.", "t5.", "llama.", "transformer.")) for k in state_dict)
is_non_diffusers_format = kohya_format or is_omi_format
if is_non_diffusers_format:
state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)