Compare commits

..

5 Commits

Author SHA1 Message Date
DN6
ea2f64b5c0 update 2026-03-09 16:27:42 +05:30
DN6
6fdc0e25d2 Merge branch 'main' into qwen-test-refactor 2026-03-09 16:20:50 +05:30
Sayak Paul
d82f91329c Merge branch 'main' into qwen-test-refactor 2026-02-15 13:40:59 +05:30
Dhruv Nair
ffdfe28983 update 2026-02-03 06:05:12 +01:00
Dhruv Nair
31ed009706 update 2026-02-02 15:48:06 +01:00
3 changed files with 281 additions and 201 deletions

View File

@@ -16,9 +16,6 @@ on:
branches:
- ci-*
permissions:
contents: read
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true

View File

@@ -14,15 +14,16 @@
import contextlib
import gc
import logging
import unittest
import pytest
import torch
from parameterized import parameterized
from diffusers import AutoencoderKL
from diffusers.hooks import HookRegistry, ModelHook
from diffusers.models import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import get_logger
from diffusers.utils.import_utils import compare_versions
from ..testing_utils import (
@@ -218,18 +219,20 @@ class NestedContainer(torch.nn.Module):
@require_torch_accelerator
class TestGroupOffload:
class GroupOffloadTests(unittest.TestCase):
in_features = 64
hidden_features = 256
out_features = 64
num_layers = 4
def setup_method(self):
def setUp(self):
with torch.no_grad():
self.model = self.get_model()
self.input = torch.randn((4, self.in_features)).to(torch_device)
def teardown_method(self):
def tearDown(self):
super().tearDown()
del self.model
del self.input
gc.collect()
@@ -245,20 +248,18 @@ class TestGroupOffload:
num_layers=self.num_layers,
)
@pytest.mark.skipif(
torch.device(torch_device).type not in ["cuda", "xpu"],
reason="Test requires a CUDA or XPU device.",
)
def test_offloading_forward_pass(self):
@torch.no_grad()
def run_forward(model):
gc.collect()
backend_empty_cache(torch_device)
backend_reset_peak_memory_stats(torch_device)
assert all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in model.modules()
if hasattr(module, "_diffusers_hook")
self.assertTrue(
all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in model.modules()
if hasattr(module, "_diffusers_hook")
)
)
model.eval()
output = model(self.input)[0].cpu()
@@ -290,37 +291,41 @@ class TestGroupOffload:
output_with_group_offloading5, mem5 = run_forward(model)
# Precision assertions - offloading should not impact the output
assert torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5)
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5))
# Memory assertions - offloading should reduce memory usage
assert mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline
self.assertTrue(mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline)
def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self, caplog):
def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
with caplog.at_level(logging.WARNING, logger="diffusers.models.modeling_utils"):
logger = get_logger("diffusers.models.modeling_utils")
logger.setLevel("INFO")
with self.assertLogs(logger, level="WARNING") as cm:
self.model.to(torch_device)
assert f"The module '{self.model.__class__.__name__}' is group offloaded" in caplog.text
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self, caplog):
def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
pipe = DummyPipeline(self.model)
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
with caplog.at_level(logging.WARNING, logger="diffusers.pipelines.pipeline_utils"):
logger = get_logger("diffusers.pipelines.pipeline_utils")
logger.setLevel("INFO")
with self.assertLogs(logger, level="WARNING") as cm:
pipe.to(torch_device)
assert f"The module '{self.model.__class__.__name__}' is group offloaded" in caplog.text
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
def test_error_raised_if_streams_used_and_no_accelerator_device(self):
torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
original_is_available = torch_accelerator_module.is_available
torch_accelerator_module.is_available = lambda: False
with pytest.raises(ValueError):
with self.assertRaises(ValueError):
self.model.enable_group_offload(
onload_device=torch.device(torch_device), offload_type="leaf_level", use_stream=True
)
@@ -328,31 +333,31 @@ class TestGroupOffload:
def test_error_raised_if_supports_group_offloading_false(self):
self.model._supports_group_offloading = False
with pytest.raises(ValueError, match="does not support group offloading"):
with self.assertRaisesRegex(ValueError, "does not support group offloading"):
self.model.enable_group_offload(onload_device=torch.device(torch_device))
def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
pipe = DummyPipeline(self.model)
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
with pytest.raises(ValueError, match="You are trying to apply model/sequential CPU offloading"):
with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
pipe.enable_model_cpu_offload()
def test_error_raised_if_sequential_offloading_applied_on_group_offloaded_module(self):
pipe = DummyPipeline(self.model)
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
with pytest.raises(ValueError, match="You are trying to apply model/sequential CPU offloading"):
with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
pipe.enable_sequential_cpu_offload()
def test_error_raised_if_group_offloading_applied_on_model_offloaded_module(self):
pipe = DummyPipeline(self.model)
pipe.enable_model_cpu_offload()
with pytest.raises(ValueError, match="Cannot apply group offloading"):
with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module(self):
pipe = DummyPipeline(self.model)
pipe.enable_sequential_cpu_offload()
with pytest.raises(ValueError, match="Cannot apply group offloading"):
with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
@@ -371,12 +376,12 @@ class TestGroupOffload:
context = contextlib.nullcontext()
if compare_versions("diffusers", "<=", "0.33.0"):
# Will raise a device mismatch RuntimeError mentioning weights are on CPU but input is on device
context = pytest.raises(RuntimeError, match="Expected all tensors to be on the same device")
context = self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device")
with context:
model(self.input)
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
@parameterized.expand([("block_level",), ("leaf_level",)])
def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
@@ -402,14 +407,14 @@ class TestGroupOffload:
out_ref = model_ref(x)
out = model(x)
assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match."
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
num_repeats = 2
for i in range(num_repeats):
out_ref = model_ref(x)
out = model(x)
assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations."
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")
for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
assert ref_name == name
@@ -423,7 +428,9 @@ class TestGroupOffload:
absdiff = diff.abs()
absmax = absdiff.max().item()
cumulated_absmax += absmax
assert cumulated_absmax < 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
self.assertLess(
cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
)
def test_vae_like_model_without_streams(self):
"""Test VAE-like model with block-level offloading but without streams."""
@@ -445,7 +452,9 @@ class TestGroupOffload:
out_ref = model_ref(x).sample
out = model(x).sample
assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
)
def test_model_with_only_standalone_layers(self):
"""Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading."""
@@ -466,11 +475,12 @@ class TestGroupOffload:
for i in range(2):
out_ref = model_ref(x)
out = model(x)
assert torch.allclose(out_ref, out, atol=1e-5), (
f"Outputs do not match at iteration {i} for model with standalone layers."
self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5),
f"Outputs do not match at iteration {i} for model with standalone layers.",
)
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
@parameterized.expand([("block_level",), ("leaf_level",)])
def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str):
"""Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
@@ -491,8 +501,9 @@ class TestGroupOffload:
out_ref = model_ref(x).sample
out = model(x).sample
assert torch.allclose(out_ref, out, atol=1e-5), (
f"Outputs do not match for standalone Conv layers with {offload_type}."
self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5),
f"Outputs do not match for standalone Conv layers with {offload_type}.",
)
def test_multiple_invocations_with_vae_like_model(self):
@@ -515,7 +526,7 @@ class TestGroupOffload:
for i in range(2):
out_ref = model_ref(x).sample
out = model(x).sample
assert torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}."
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.")
def test_nested_container_parameters_offloading(self):
"""Test that parameters from non-computational layers in nested containers are handled correctly."""
@@ -536,8 +547,9 @@ class TestGroupOffload:
for i in range(2):
out_ref = model_ref(x)
out = model(x)
assert torch.allclose(out_ref, out, atol=1e-5), (
f"Outputs do not match at iteration {i} for nested parameters."
self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5),
f"Outputs do not match at iteration {i} for nested parameters.",
)
def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
@@ -590,7 +602,7 @@ class DummyModelWithConditionalModules(ModelMixin):
return x
class TestConditionalModuleGroupOffload(TestGroupOffload):
class ConditionalModuleGroupOffloadTests(GroupOffloadTests):
"""Tests for conditionally-executed modules under group offloading with streams.
Regression tests for the case where a module is not executed during the first forward pass
@@ -608,10 +620,10 @@ class TestConditionalModuleGroupOffload(TestGroupOffload):
num_layers=self.num_layers,
)
@pytest.mark.parametrize("offload_type", ["leaf_level", "block_level"])
@pytest.mark.skipif(
@parameterized.expand([("leaf_level",), ("block_level",)])
@unittest.skipIf(
torch.device(torch_device).type not in ["cuda", "xpu"],
reason="Test requires a CUDA or XPU device.",
"Test requires a CUDA or XPU device.",
)
def test_conditional_modules_with_stream(self, offload_type: str):
"""Regression test: conditionally-executed modules must not cause device mismatch when using streams.
@@ -658,20 +670,23 @@ class TestConditionalModuleGroupOffload(TestGroupOffload):
# execution order is traced. optional_proj_1/2 are NOT in the traced order.
out_ref_no_opt = model_ref(x, optional_input=None)
out_no_opt = model(x, optional_input=None)
assert torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5), (
f"[{offload_type}] Outputs do not match on first pass (no optional_input)."
self.assertTrue(
torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5),
f"[{offload_type}] Outputs do not match on first pass (no optional_input).",
)
# Second forward pass WITH optional_input — optional_proj_1/2 ARE now called.
out_ref_with_opt = model_ref(x, optional_input=optional_input)
out_with_opt = model(x, optional_input=optional_input)
assert torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5), (
f"[{offload_type}] Outputs do not match on second pass (with optional_input)."
self.assertTrue(
torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5),
f"[{offload_type}] Outputs do not match on second pass (with optional_input).",
)
# Third pass again without optional_input — verify stable behavior.
out_ref_no_opt2 = model_ref(x, optional_input=None)
out_no_opt2 = model(x, optional_input=None)
assert torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5), (
f"[{offload_type}] Outputs do not match on third pass (back to no optional_input)."
self.assertTrue(
torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5),
f"[{offload_type}] Outputs do not match on third pass (back to no optional_input).",
)

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,49 +12,84 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import warnings
import torch
from diffusers import QwenImageTransformer2DModel
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return QwenImageTransformer2DModel
@property
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
def output_shape(self) -> tuple[int, int]:
return (16, 16)
@property
def output_shape(self):
def input_shape(self) -> tuple[int, int]:
return (16, 16)
def prepare_dummy_input(self, height=4, width=4):
@property
def model_split_percents(self) -> list:
return [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
return {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 4,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 7
height = width = 4
sequence_length = 8
vae_scale_factor = 4
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
@@ -70,89 +104,57 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
"img_shapes": img_shapes,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
def test_infers_text_seq_len_from_mask(self):
"""Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
# Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid
encoder_hidden_states_mask[:, 2:] = 0
rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask
)
# Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
self.assertIsInstance(rope_text_seq_len, int)
assert isinstance(rope_text_seq_len, int)
assert isinstance(per_sample_len, torch.Tensor)
assert int(per_sample_len.max().item()) == 2
assert normalized_mask.dtype == torch.bool
assert normalized_mask.sum().item() == 2
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
# Verify per_sample_len is computed correctly (max valid position + 1 = 2)
self.assertIsInstance(per_sample_len, torch.Tensor)
self.assertEqual(int(per_sample_len.max().item()), 2)
# Verify mask is normalized to bool dtype
self.assertTrue(normalized_mask.dtype == torch.bool)
self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values
# Verify rope_text_seq_len is at least the sequence length
self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
# Test 2: Verify model runs successfully with inferred values
inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad():
output = model(**inputs)
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
# Test 3: Different mask pattern (padding at beginning)
encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding
encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid
encoder_hidden_states_mask2[:, :3] = 0
encoder_hidden_states_mask2[:, 3:] = 1
rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask2
)
# Max valid position is 6 (last token), so per_sample_len should be 7
self.assertEqual(int(per_sample_len2.max().item()), 7)
self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values
assert int(per_sample_len2.max().item()) == 8
assert normalized_mask2.sum().item() == 5
# Test 4: No mask provided (None case)
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], None
)
self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(rope_text_seq_len_none, int)
self.assertIsNone(per_sample_len_none)
self.assertIsNone(normalized_mask_none)
assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1]
assert isinstance(rope_text_seq_len_none, int)
assert per_sample_len_none is None
assert normalized_mask_none is None
def test_non_contiguous_attention_mask(self):
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
# Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
# Pattern: [True, False, True, False, True, False, False]
encoder_hidden_states_mask[:, 1] = 0
encoder_hidden_states_mask[:, 3] = 0
encoder_hidden_states_mask[:, 5:] = 0
@@ -160,95 +162,85 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask
)
self.assertEqual(int(per_sample_len.max().item()), 5)
self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(inferred_rope_len, int)
self.assertTrue(normalized_mask.dtype == torch.bool)
assert int(per_sample_len.max().item()) == 5
assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1]
assert isinstance(inferred_rope_len, int)
assert normalized_mask.dtype == torch.bool
inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad():
output = model(**inputs)
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
def test_txt_seq_lens_deprecation(self):
"""Test that passing txt_seq_lens raises a deprecation warning."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
# Prepare inputs with txt_seq_lens (deprecated parameter)
txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]
# Remove encoder_hidden_states_mask to use the deprecated path
inputs_with_deprecated = inputs.copy()
inputs_with_deprecated.pop("encoder_hidden_states_mask")
inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens
# Test that deprecation warning is raised
with self.assertWarns(FutureWarning) as warning_context:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with torch.no_grad():
output = model(**inputs_with_deprecated)
# Verify the warning message mentions the deprecation
warning_message = str(warning_context.warning)
self.assertIn("txt_seq_lens", warning_message)
self.assertIn("deprecated", warning_message)
self.assertIn("encoder_hidden_states_mask", warning_message)
future_warnings = [x for x in w if issubclass(x.category, FutureWarning)]
assert len(future_warnings) > 0, "Expected FutureWarning to be raised"
# Verify the model still works correctly despite the deprecation
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
warning_message = str(future_warnings[0].message)
assert "txt_seq_lens" in warning_message
assert "deprecated" in warning_message
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
def test_layered_model_with_mask(self):
"""Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
# Create layered model config
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"num_attention_heads": 4,
"joint_attention_dim": 16,
"axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16)
"use_layer3d_rope": True, # Enable layered RoPE
"use_additional_t_cond": True, # Enable additional time conditioning
"axes_dims_rope": (8, 4, 4),
"use_layer3d_rope": True,
"use_additional_t_cond": True,
}
model = self.model_class(**init_dict).to(torch_device)
# Verify the model uses QwenEmbedLayer3DRope
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope
self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
assert isinstance(model.pos_embed, QwenEmbedLayer3DRope)
# Test single generation with layered structure
batch_size = 1
text_seq_len = 7
text_seq_len = 8
img_h, img_w = 4, 4
layers = 4
# For layered model: (layers + 1) because we have N layers + 1 combined image
hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)
# Create mask with some padding
encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens
encoder_hidden_states_mask[0, 5:] = 0
timestep = torch.tensor([1.0]).to(torch_device)
# additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding)
addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)
# Layer structure: 4 layers + 1 condition image
img_shapes = [
[
(1, img_h, img_w), # layer 0
(1, img_h, img_w), # layer 1
(1, img_h, img_w), # layer 2
(1, img_h, img_w), # layer 3
(1, img_h, img_w), # condition image (last one gets special treatment)
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w),
]
]
@@ -262,37 +254,113 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
additional_t_cond=addition_t_cond,
)
self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
assert output.sample.shape[1] == hidden_states.shape[1]
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for QwenImage Transformer."""
def prepare_init_args_and_inputs_for_common(self):
return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
def prepare_dummy_input(self, height, width):
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for QwenImage Transformer."""
def test_torch_compile_recompilation_and_graph_break(self):
super().test_torch_compile_recompilation_and_graph_break()
def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for QwenImage Transformer."""
class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for QwenImage Transformer."""
class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for QwenImage Transformer."""
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for QwenImage Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 8
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for QwenImage Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 8
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
def test_torch_compile_with_and_without_mask(self):
"""Test that torch.compile works with both None mask and padding mask."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile(mode="default", fullgraph=True)
# Test 1: Run with None mask (no padding, all tokens are valid)
inputs_no_mask = inputs.copy()
inputs_no_mask["encoder_hidden_states_mask"] = None
# First run to allow compilation
with torch.no_grad():
output_no_mask = model(**inputs_no_mask)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
@@ -300,19 +368,15 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
):
output_no_mask_2 = model(**inputs_no_mask)
self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1]
# Test 2: Run with all-ones mask (should behave like None)
inputs_all_ones = inputs.copy()
# Keep the all-ones mask
self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
assert inputs_all_ones["encoder_hidden_states_mask"].all().item()
# First run to allow compilation
with torch.no_grad():
output_all_ones = model(**inputs_all_ones)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
@@ -320,21 +384,18 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
):
output_all_ones_2 = model(**inputs_all_ones)
self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1]
# Test 3: Run with actual padding mask (has zeros)
inputs_with_padding = inputs.copy()
mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
mask_with_padding[:, 4:] = 0 # Last 3 tokens are padding
mask_with_padding[:, 4:] = 0
inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding
# First run to allow compilation
with torch.no_grad():
output_with_padding = model(**inputs_with_padding)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
@@ -342,8 +403,15 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
):
output_with_padding_2 = model(**inputs_with_padding)
self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1]
# Verify that outputs are different (mask should affect results)
self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)
class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for QwenImage Transformer."""
class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for QwenImage Transformer."""