Compare commits

...

2 Commits

Author SHA1 Message Date
Sayak Paul
4749c4f74a Merge branch 'main' into z-image-distillation-lora 2026-02-12 21:21:37 +05:30
Álvaro Somoza
2b16351270 conversion 2026-02-11 16:05:55 -03:00

View File

@@ -2455,18 +2455,22 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
if has_diffusion_model:
state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
has_lora_unet = any(k.startswith("lora_unet_") or k.startswith("lora_unet__") for k in state_dict)
if has_lora_unet:
state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
state_dict = {k.removeprefix("lora_unet__").removeprefix("lora_unet_"): v for k, v in state_dict.items()}
def convert_key(key: str) -> str:
# ZImage has: layers, noise_refiner, context_refiner blocks
# Keys may be like: layers_0_attention_to_q.lora_down.weight
if "." in key:
base, suffix = key.rsplit(".", 1)
else:
base, suffix = key, ""
suffix = ""
for sfx in (".lora_down.weight", ".lora_up.weight", ".alpha"):
if key.endswith(sfx):
base = key[: -len(sfx)]
suffix = sfx
break
else:
base = key
# Protected n-grams that must keep their internal underscores
protected = {
@@ -2477,6 +2481,9 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
("to", "out"),
# feed_forward
("feed", "forward"),
# noise and context refiner
("noise", "refiner"),
("context", "refiner"),
}
prot_by_len = {}
@@ -2501,7 +2508,7 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
i += 1
converted_base = ".".join(merged)
return converted_base + (("." + suffix) if suffix else "")
return converted_base + suffix
state_dict = {convert_key(k): v for k, v in state_dict.items()}