Compare commits

...

6 Commits

Author SHA1 Message Date
Daniel Gu
150047f51a Further improve incorrect LoRA format error msg following review 2026-01-12 23:59:44 +01:00
Sayak Paul
23db559608 Merge branch 'main' into improve-lora-loaders 2026-01-12 09:03:44 +05:30
Daniel Gu
0b10746140 Apply changes to LTX2LoraTests 2026-01-10 08:02:24 +01:00
Daniel Gu
bd91810f4c Merge branch 'main' into improve-lora-loaders 2026-01-10 07:55:28 +01:00
Daniel Gu
dc43efbc4c Add flag in PeftLoraLoaderMixinTests to disable text encoder LoRA tests 2026-01-10 07:54:27 +01:00
Daniel Gu
51dc061ee6 Improve incorrect LoRA format error message 2026-01-10 06:15:36 +01:00
16 changed files with 67 additions and 304 deletions

View File

@@ -214,7 +214,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_unet(
state_dict,
@@ -641,7 +641,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_unet(
state_dict,
@@ -1081,7 +1081,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -1377,7 +1377,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -1659,7 +1659,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
)
if not (has_lora_keys or has_norm_keys):
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
transformer_lora_state_dict = {
k: state_dict.get(k)
@@ -2506,7 +2506,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -2703,7 +2703,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -2906,7 +2906,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -3115,7 +3115,7 @@ class LTX2LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
transformer_peft_state_dict = {
k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")
@@ -3333,7 +3333,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -3536,7 +3536,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -3740,7 +3740,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -3940,7 +3940,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -4194,7 +4194,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
@@ -4471,7 +4471,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
@@ -4691,7 +4691,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -4894,7 +4894,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -5100,7 +5100,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -5306,7 +5306,7 @@ class ZImageLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
@@ -5509,7 +5509,7 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,

View File

@@ -76,6 +76,8 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 8, 8, 3)
@@ -114,23 +116,3 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in AuraFlow.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -87,6 +87,8 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 9, 16, 16, 3)
@@ -147,26 +149,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass

View File

@@ -85,6 +85,8 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"text_encoder",
)
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 32, 32, 3)
@@ -162,23 +164,3 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in CogView4.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -66,6 +66,8 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_cls, text_encoder_id = Mistral3ForConditionalGeneration, "hf-internal-testing/tiny-mistral3-diffusers"
denoiser_target_modules = ["to_qkv_mlp_proj", "to_k"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 8, 8, 3)
@@ -146,23 +148,3 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in Flux2.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -117,6 +117,8 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"text_encoder_2",
)
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 9, 32, 32, 3)
@@ -172,26 +174,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@nightly
@require_torch_accelerator

View File

@@ -150,6 +150,8 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
denoiser_target_modules = ["to_q", "to_k", "to_out.0"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 5, 32, 32, 3)
@@ -267,27 +269,3 @@ class LTX2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in LTX2.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTX2.")
def test_simple_inference_save_pretrained_with_text_lora(self):
pass

View File

@@ -76,6 +76,8 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 9, 32, 32, 3)
@@ -125,23 +127,3 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in LTXVideo.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -74,6 +74,8 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/dummy-gemma"
text_encoder_cls, text_encoder_id = GemmaForCausalLM, "hf-internal-testing/dummy-gemma-diffusers"
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 4, 4, 3)
@@ -113,26 +115,6 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@skip_mps
@pytest.mark.xfail(
condition=torch.device(torch_device).type == "cpu" and is_torch_version(">=", "2.5"),

View File

@@ -67,6 +67,8 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 7, 16, 16, 3)
@@ -117,26 +119,6 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass

View File

@@ -69,6 +69,8 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
)
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 8, 8, 3)
@@ -107,23 +109,3 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in Qwen Image.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -75,6 +75,8 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
tokenizer_cls, tokenizer_id = GemmaTokenizer, "hf-internal-testing/dummy-gemma"
text_encoder_cls, text_encoder_id = Gemma2Model, "hf-internal-testing/dummy-gemma-for-diffusers"
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 32, 32, 3)
@@ -117,26 +119,6 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_inference_denoiser(self):
return super().test_layerwise_casting_inference_denoiser()

View File

@@ -73,6 +73,8 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 9, 32, 32, 3)
@@ -121,23 +123,3 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in Wan.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -85,6 +85,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_target_modules = ["q", "k", "v", "o"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 9, 16, 16, 3)
@@ -139,26 +141,6 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_save_load(self):
pass
def test_layerwise_casting_inference_denoiser(self):
super().test_layerwise_casting_inference_denoiser()

View File

@@ -75,6 +75,8 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
text_encoder_cls, text_encoder_id = Qwen3Model, None # Will be created inline
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
supports_text_encoder_loras = False
@property
def output_shape(self):
return (1, 32, 32, 3)
@@ -263,23 +265,3 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Not supported in ZImage.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -117,6 +117,7 @@ class PeftLoraLoaderMixinTests:
tokenizer_cls, tokenizer_id, tokenizer_subfolder = None, None, ""
tokenizer_2_cls, tokenizer_2_id, tokenizer_2_subfolder = None, None, ""
tokenizer_3_cls, tokenizer_3_id, tokenizer_3_subfolder = None, None, ""
supports_text_encoder_loras = True
unet_kwargs = None
transformer_cls = None
@@ -333,6 +334,9 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached on the text encoder
and makes sure it works as expected
"""
if not self.supports_text_encoder_loras:
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -457,6 +461,9 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached on the text encoder + scale argument
and makes sure it works as expected
"""
if not self.supports_text_encoder_loras:
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
@@ -494,6 +501,9 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected
"""
if not self.supports_text_encoder_loras:
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -555,6 +565,9 @@ class PeftLoraLoaderMixinTests:
"""
Tests a simple usecase where users could use saving utilities for LoRA.
"""
if not self.supports_text_encoder_loras:
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -593,6 +606,9 @@ class PeftLoraLoaderMixinTests:
with different ranks and some adapters removed
and makes sure it works as expected
"""
if not self.supports_text_encoder_loras:
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
components, _, _ = self.get_dummy_components()
# Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
text_lora_config = LoraConfig(
@@ -651,6 +667,9 @@ class PeftLoraLoaderMixinTests:
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
if not self.supports_text_encoder_loras:
pytest.skip("Skipping test as text encoder LoRAs are not currently supported.")
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)