Compare commits


6 Commits

Author          SHA1        Message                                                            Date
sayakpaul       6da67f293f  up                                                                 2025-12-03 20:56:34 +08:00
sayakpaul       21d8a8dc4b  Merge branch 'main' into z-image-tests                             2025-12-03 20:44:13 +08:00
sayakpaul       5613ff0143  up                                                                 2025-12-03 20:43:48 +08:00
Sayak Paul      d96cbacacd  [tests] fix hunuyanvideo 1.5 offloading tests. (#12782)           2025-12-03 18:07:59 +05:30
                              fix hunuyanvideo 1.5 offloading tests.
Aditya Borate   5ab5946931  Fix: leaf_level offloading breaks after delete_adapters (#12639)  2025-12-03 17:39:11 +05:30
                              * Fix(peft): Re-apply group offloading after deleting adapters
                              * Test: Add regression test for group offloading + delete_adapters
                              * Test: Add assertions to verify output changes after deletion
                              * Test: Add try/finally to clean up group offloading hooks
                              Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
sayakpaul       4ca68f2f75  skipping ZImage DiT tests                                          2025-12-03 19:26:23 +08:00
6 changed files with 163 additions and 18 deletions
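
The headline change in this range is #12639: with leaf_level group offloading enabled on the denoiser, calling delete_adapters() used to leave the offloading hooks in a broken state. A minimal repro sketch of that scenario, using only the public APIs exercised by the regression test further down; the checkpoint id, LoRA repo, and prompt below are illustrative placeholders, not taken from this PR:

import torch
from diffusers import DiffusionPipeline
from diffusers.hooks.group_offloading import apply_group_offloading

# Placeholder checkpoint: any LoRA-capable pipeline with a transformer denoiser.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Stream the denoiser weights in per leaf module, keeping them on CPU otherwise.
apply_group_offloading(
    pipe.transformer,
    onload_device="cuda",
    offload_device="cpu",
    offload_type="leaf_level",
)

pipe.load_lora_weights("user/some-lora", adapter_name="default")  # placeholder repo id
image_with_lora = pipe("a prompt", num_inference_steps=2).images[0]

# Before this fix, the group offloading hooks still referenced the LoRA
# submodules that delete_adapters() removed; the fix re-applies offloading
# inside delete_adapters() so the second run works and ignores the LoRA.
pipe.delete_adapters("default")
image_without_lora = pipe("a prompt", num_inference_steps=2).images[0]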

View File

@@ -22,6 +22,7 @@ from typing import Dict, List, Literal, Optional, Union
import safetensors
import torch
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
@@ -794,6 +795,8 @@ class PeftAdapterMixin:
if hasattr(self, "peft_config"):
self.peft_config.pop(adapter_name, None)
_maybe_remove_and_reapply_group_offloading(self)
def enable_lora_hotswap(
self, target_rank: int = 128, check_compiled: Literal["error", "warn", "ignore"] = "error"
) -> None:
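
Conceptually, the helper called above re-installs group offloading over the module tree after adapter layers have been removed. A rough, simplified sketch of that idea, using only the hook APIs that appear elsewhere in this diff; the real _maybe_remove_and_reapply_group_offloading also recovers the original offloading configuration from the existing hooks, which is omitted here and passed in explicitly instead:

from diffusers.hooks.group_offloading import _GROUP_OFFLOADING, apply_group_offloading

def reapply_group_offloading(module, onload_device, offload_device="cpu"):
    # Simplified assumption: the caller supplies the offloading config again.
    registry = getattr(module, "_diffusers_hook", None)
    if registry is None:
        return  # group offloading was never enabled on this module
    # Drop hooks that still point at the old (pre-delete_adapters) submodules ...
    registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
    # ... and rebuild them over the current module tree.
    apply_group_offloading(
        module,
        onload_device=onload_device,
        offload_device=offload_device,
        offload_type="leaf_level",
    )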

View File

@@ -15,12 +15,13 @@
import sys
import unittest
import numpy as np
import torch
from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel
from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend
from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend, skip_mps, torch_device
if is_peft_available():
@@ -29,13 +30,9 @@ if is_peft_available():
sys.path.append(".")
from .utils import PeftLoraLoaderMixinTests # noqa: E402
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
# @unittest.skip(
# "ZImage LoRA tests are skipped due to non-deterministic behavior from complex64 RoPE operations "
# "and torch.empty padding tokens. LoRA functionality works correctly with real models."
# )
@require_peft_backend
class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = ZImagePipeline
@@ -163,34 +160,126 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return pipeline_components, text_lora_config, denoiser_lora_config
@unittest.skip("Not supported in Flux2.")
def test_correct_lora_configs_with_different_ranks(self):
components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.transformer.delete_adapters("adapter-1")
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
for name, _ in denoiser.named_modules():
if "to_k" in name and "attention" in name and "lora" not in name:
module_name_to_rank_update = name.replace(".base_layer.", ".")
break
# change the rank_pattern
updated_rank = denoiser_lora_config.r * 2
denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern
self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})
lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
pipe.transformer.delete_adapters("adapter-1")
# similarly change the alpha_pattern
updated_alpha = denoiser_lora_config.lora_alpha * 2
denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(
pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
)
lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
@skip_mps
def test_lora_fuse_nan(self):
components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
possible_tower_names = ["noise_refiner"]
filtered_tower_names = [
tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
]
for tower_name in filtered_tower_names:
transformer_tower = getattr(pipe.transformer, tower_name)
transformer_tower[0].attention.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all())
def test_lora_scale_kwargs_match_fusion(self):
super().test_lora_scale_kwargs_match_fusion(5e-2, 5e-2)
@unittest.skip("Needs to be debugged.")
def test_set_adapters_match_attention_kwargs(self):
super().test_set_adapters_match_attention_kwargs()
@unittest.skip("Needs to be debugged.")
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
super().test_simple_inference_with_text_denoiser_lora_and_scale()
@unittest.skip("Not supported in ZImage.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Flux2.")
@unittest.skip("Not supported in ZImage.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Flux2.")
@unittest.skip("Not supported in ZImage.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
@unittest.skip("Text encoder LoRA is not supported in ZImage.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -28,6 +28,7 @@ from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
)
from diffusers.hooks.group_offloading import _GROUP_OFFLOADING, apply_group_offloading
from diffusers.utils import logging
from diffusers.utils.import_utils import is_peft_available
@@ -2367,3 +2368,51 @@ class PeftLoraLoaderMixinTests:
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3))
@require_torch_accelerator
def test_lora_group_offloading_delete_adapters(self):
components, _, denoiser_lora_config = self.get_dummy_components()
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
try:
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
)
components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
pipe.to(torch_device)
# Enable Group Offloading (leaf_level for more granular testing)
apply_group_offloading(
denoiser,
onload_device=torch_device,
offload_device="cpu",
offload_type="leaf_level",
)
pipe.load_lora_weights(tmpdirname, adapter_name="default")
out_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
# Delete the adapter
pipe.delete_adapters("default")
out_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(out_lora, out_no_lora, atol=1e-3, rtol=1e-3))
finally:
# Clean up the hooks to prevent state leak
if hasattr(denoiser, "_diffusers_hook"):
denoiser._diffusers_hook.remove_hook(_GROUP_OFFLOADING, recurse=True)

View File

@@ -29,6 +29,7 @@ class HunyuanVideo15Transformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideo15Transformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.99, 0.99, 0.99]
text_embed_dim = 16
text_embed_2_dim = 8

View File

@@ -21,7 +21,7 @@ import torch
from diffusers import ZImageTransformer2DModel
from ...testing_utils import torch_device
from ...testing_utils import IS_GITHUB_ACTIONS, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
@@ -36,6 +36,10 @@ if hasattr(torch.backends, "cuda"):
torch.backends.cuda.matmul.allow_tf32 = False
@unittest.skipIf(
IS_GITHUB_ACTIONS,
reason="Skipping test-suite inside the CI because the model has `torch.empty()` inside of it during init and we don't have a clear way to override it in the modeling tests.",
)
class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = ZImageTransformer2DModel
main_input_name = "x"
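
For context on the skip reason above: torch.empty() returns uninitialized memory, so a module that builds buffers (e.g. padding tokens) with it at init time can yield different values on every construction, even under a fixed seed. A tiny standalone illustration, unrelated to the actual ZImage code:

import torch

def build_buffer():
    torch.manual_seed(0)
    # torch.empty() allocates without initializing, so seeding the RNG has no
    # effect here; the values are whatever bytes happened to be in memory.
    return torch.empty(4, 8)

a = build_buffer()
b = build_buffer()
# There is no guarantee these match, which is what makes output-comparison
# tests on such a model non-deterministic.
print(torch.equal(a, b))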

View File

@@ -22,7 +22,7 @@ from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel
from ...testing_utils import is_flaky, torch_device
from ...testing_utils import torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -170,7 +170,6 @@ class ZImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
return inputs
@is_flaky(max_attempts=10)
def test_inference(self):
device = "cpu"
@@ -185,7 +184,7 @@ class ZImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
self.assertEqual(generated_image.shape, (3, 32, 32))
# fmt: off
expected_slice = torch.tensor([0.4521, 0.4512, 0.4693, 0.5115, 0.5250, 0.5271, 0.4776, 0.4688, 0.2765, 0.2164, 0.5656, 0.6909, 0.3831, 0.5431, 0.5493, 0.4732])
expected_slice = torch.tensor([0.4622, 0.4532, 0.4714, 0.5087, 0.5371, 0.5405, 0.4492, 0.4479, 0.2984, 0.2783, 0.5409, 0.6577, 0.3952, 0.5524, 0.5262, 0.453])
# fmt: on
generated_slice = generated_image.flatten()