Compare commits

..

4 Commits

Author SHA1 Message Date
DN6
d90a4dfe57 update 2026-03-19 13:16:40 +05:30
DN6
5219182752 update 2026-03-19 13:08:43 +05:30
DN6
189491a4f2 update 2026-03-19 13:05:08 +05:30
Dhruv Nair
11a3284cee [CI] Qwen Image Model Test Refactor (#13069)
* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-17 16:44:04 +05:30
3 changed files with 329 additions and 211 deletions

View File

@@ -22,7 +22,7 @@ from typing import Set
import safetensors.torch
import torch
from ..utils import get_logger, is_accelerate_available, is_torchao_available
from ..utils import get_logger, is_accelerate_available
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook
@@ -35,41 +35,6 @@ if is_accelerate_available():
logger = get_logger(__name__) # pylint: disable=invalid-name
def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
"""Check if a tensor is a TorchAO quantized tensor subclass."""
if not is_torchao_available():
return False
from torchao.utils import TorchAOBaseTensor
return isinstance(tensor, TorchAOBaseTensor)
def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
"""Get names of all internal tensor data attributes from a TorchAO tensor."""
cls = type(tensor)
names = list(getattr(cls, "tensor_data_names", []))
for attr_name in getattr(cls, "optional_tensor_data_names", []):
if getattr(tensor, attr_name, None) is not None:
names.append(attr_name)
return names
def _update_torchao_tensor_in_place(param: torch.Tensor, source: torch.Tensor) -> None:
"""Update internal tensor data of a TorchAO parameter in-place from source.
Must operate on the parameter/buffer object directly (not ``param.data``) because ``_make_wrapper_subclass``
returns a fresh wrapper from ``.data`` each time, so attribute mutations on ``.data`` are lost.
"""
for attr_name in _get_torchao_inner_tensor_names(source):
setattr(param, attr_name, getattr(source, attr_name))
def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
"""Record stream for all internal tensors of a TorchAO parameter."""
for attr_name in _get_torchao_inner_tensor_names(param):
getattr(param, attr_name).record_stream(stream)
# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
@@ -192,16 +157,9 @@ class ModuleGroup:
pinned_dict = None
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if _is_torchao_tensor(tensor):
_update_torchao_tensor_in_place(tensor, moved)
else:
tensor.data = moved
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
if _is_torchao_tensor(tensor):
_record_stream_torchao_tensor(tensor, default_stream)
else:
tensor.data.record_stream(default_stream)
tensor.data.record_stream(default_stream)
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
for group_module in self.modules:
@@ -287,35 +245,18 @@ class ModuleGroup:
for group_module in self.modules:
for param in group_module.parameters():
if _is_torchao_tensor(param):
_update_torchao_tensor_in_place(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for param in self.parameters:
if _is_torchao_tensor(param):
_update_torchao_tensor_in_place(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for param in self.parameters:
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
if _is_torchao_tensor(buffer):
_update_torchao_tensor_in_place(buffer, self.cpu_param_dict[buffer])
else:
buffer.data = self.cpu_param_dict[buffer]
buffer.data = self.cpu_param_dict[buffer]
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
if _is_torchao_tensor(param):
moved = param.data.to(self.offload_device, non_blocking=False)
_update_torchao_tensor_in_place(param, moved)
else:
param.data = param.data.to(self.offload_device, non_blocking=False)
param.data = param.data.to(self.offload_device, non_blocking=False)
for buffer in self.buffers:
if _is_torchao_tensor(buffer):
moved = buffer.data.to(self.offload_device, non_blocking=False)
_update_torchao_tensor_in_place(buffer, moved)
else:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
@torch.compiler.disable()
def onload_(self):

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,49 +12,84 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import warnings
import torch
from diffusers import QwenImageTransformer2DModel
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return QwenImageTransformer2DModel
@property
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
def output_shape(self) -> tuple[int, int]:
return (16, 16)
@property
def output_shape(self):
def input_shape(self) -> tuple[int, int]:
return (16, 16)
def prepare_dummy_input(self, height=4, width=4):
@property
def model_split_percents(self) -> list:
return [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
return {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 4,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 7
height = width = 4
sequence_length = 8
vae_scale_factor = 4
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
@@ -70,89 +104,57 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
"img_shapes": img_shapes,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
def test_infers_text_seq_len_from_mask(self):
"""Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
# Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid
encoder_hidden_states_mask[:, 2:] = 0
rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask
)
# Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
self.assertIsInstance(rope_text_seq_len, int)
assert isinstance(rope_text_seq_len, int)
assert isinstance(per_sample_len, torch.Tensor)
assert int(per_sample_len.max().item()) == 2
assert normalized_mask.dtype == torch.bool
assert normalized_mask.sum().item() == 2
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
# Verify per_sample_len is computed correctly (max valid position + 1 = 2)
self.assertIsInstance(per_sample_len, torch.Tensor)
self.assertEqual(int(per_sample_len.max().item()), 2)
# Verify mask is normalized to bool dtype
self.assertTrue(normalized_mask.dtype == torch.bool)
self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values
# Verify rope_text_seq_len is at least the sequence length
self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
# Test 2: Verify model runs successfully with inferred values
inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad():
output = model(**inputs)
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
# Test 3: Different mask pattern (padding at beginning)
encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding
encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid
encoder_hidden_states_mask2[:, :3] = 0
encoder_hidden_states_mask2[:, 3:] = 1
rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask2
)
# Max valid position is 6 (last token), so per_sample_len should be 7
self.assertEqual(int(per_sample_len2.max().item()), 7)
self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values
assert int(per_sample_len2.max().item()) == 8
assert normalized_mask2.sum().item() == 5
# Test 4: No mask provided (None case)
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], None
)
self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(rope_text_seq_len_none, int)
self.assertIsNone(per_sample_len_none)
self.assertIsNone(normalized_mask_none)
assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1]
assert isinstance(rope_text_seq_len_none, int)
assert per_sample_len_none is None
assert normalized_mask_none is None
def test_non_contiguous_attention_mask(self):
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
# Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
# Pattern: [True, False, True, False, True, False, False]
encoder_hidden_states_mask[:, 1] = 0
encoder_hidden_states_mask[:, 3] = 0
encoder_hidden_states_mask[:, 5:] = 0
@@ -160,95 +162,85 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask
)
self.assertEqual(int(per_sample_len.max().item()), 5)
self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(inferred_rope_len, int)
self.assertTrue(normalized_mask.dtype == torch.bool)
assert int(per_sample_len.max().item()) == 5
assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1]
assert isinstance(inferred_rope_len, int)
assert normalized_mask.dtype == torch.bool
inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad():
output = model(**inputs)
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
def test_txt_seq_lens_deprecation(self):
"""Test that passing txt_seq_lens raises a deprecation warning."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
# Prepare inputs with txt_seq_lens (deprecated parameter)
txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]
# Remove encoder_hidden_states_mask to use the deprecated path
inputs_with_deprecated = inputs.copy()
inputs_with_deprecated.pop("encoder_hidden_states_mask")
inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens
# Test that deprecation warning is raised
with self.assertWarns(FutureWarning) as warning_context:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with torch.no_grad():
output = model(**inputs_with_deprecated)
# Verify the warning message mentions the deprecation
warning_message = str(warning_context.warning)
self.assertIn("txt_seq_lens", warning_message)
self.assertIn("deprecated", warning_message)
self.assertIn("encoder_hidden_states_mask", warning_message)
future_warnings = [x for x in w if issubclass(x.category, FutureWarning)]
assert len(future_warnings) > 0, "Expected FutureWarning to be raised"
# Verify the model still works correctly despite the deprecation
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
warning_message = str(future_warnings[0].message)
assert "txt_seq_lens" in warning_message
assert "deprecated" in warning_message
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
def test_layered_model_with_mask(self):
"""Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
# Create layered model config
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"num_attention_heads": 4,
"joint_attention_dim": 16,
"axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16)
"use_layer3d_rope": True, # Enable layered RoPE
"use_additional_t_cond": True, # Enable additional time conditioning
"axes_dims_rope": (8, 4, 4),
"use_layer3d_rope": True,
"use_additional_t_cond": True,
}
model = self.model_class(**init_dict).to(torch_device)
# Verify the model uses QwenEmbedLayer3DRope
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope
self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
assert isinstance(model.pos_embed, QwenEmbedLayer3DRope)
# Test single generation with layered structure
batch_size = 1
text_seq_len = 7
text_seq_len = 8
img_h, img_w = 4, 4
layers = 4
# For layered model: (layers + 1) because we have N layers + 1 combined image
hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)
# Create mask with some padding
encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens
encoder_hidden_states_mask[0, 5:] = 0
timestep = torch.tensor([1.0]).to(torch_device)
# additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding)
addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)
# Layer structure: 4 layers + 1 condition image
img_shapes = [
[
(1, img_h, img_w), # layer 0
(1, img_h, img_w), # layer 1
(1, img_h, img_w), # layer 2
(1, img_h, img_w), # layer 3
(1, img_h, img_w), # condition image (last one gets special treatment)
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w),
]
]
@@ -262,37 +254,113 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
additional_t_cond=addition_t_cond,
)
self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
assert output.sample.shape[1] == hidden_states.shape[1]
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for QwenImage Transformer."""
def prepare_init_args_and_inputs_for_common(self):
return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
def prepare_dummy_input(self, height, width):
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for QwenImage Transformer."""
def test_torch_compile_recompilation_and_graph_break(self):
super().test_torch_compile_recompilation_and_graph_break()
def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for QwenImage Transformer."""
class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for QwenImage Transformer."""
class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for QwenImage Transformer."""
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for QwenImage Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 8
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for QwenImage Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 8
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
def test_torch_compile_with_and_without_mask(self):
"""Test that torch.compile works with both None mask and padding mask."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile(mode="default", fullgraph=True)
# Test 1: Run with None mask (no padding, all tokens are valid)
inputs_no_mask = inputs.copy()
inputs_no_mask["encoder_hidden_states_mask"] = None
# First run to allow compilation
with torch.no_grad():
output_no_mask = model(**inputs_no_mask)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
@@ -300,19 +368,15 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
):
output_no_mask_2 = model(**inputs_no_mask)
self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1]
# Test 2: Run with all-ones mask (should behave like None)
inputs_all_ones = inputs.copy()
# Keep the all-ones mask
self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
assert inputs_all_ones["encoder_hidden_states_mask"].all().item()
# First run to allow compilation
with torch.no_grad():
output_all_ones = model(**inputs_all_ones)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
@@ -320,21 +384,18 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
):
output_all_ones_2 = model(**inputs_all_ones)
self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1]
# Test 3: Run with actual padding mask (has zeros)
inputs_with_padding = inputs.copy()
mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
mask_with_padding[:, 4:] = 0 # Last 3 tokens are padding
mask_with_padding[:, 4:] = 0
inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding
# First run to allow compilation
with torch.no_grad():
output_with_padding = model(**inputs_with_padding)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
@@ -342,8 +403,15 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
):
output_with_padding_2 = model(**inputs_with_padding)
self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1]
# Verify that outputs are different (mask should affect results)
self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)
class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for QwenImage Transformer."""
class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for QwenImage Transformer."""

View File

@@ -31,7 +31,41 @@ from diffusers.modular_pipelines import (
WanModularPipeline,
)
from ..testing_utils import nightly, require_torch, slow
from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device
def _create_tiny_model_dir(model_dir):
TINY_MODEL_CODE = (
"import torch\n"
"from diffusers import ModelMixin, ConfigMixin\n"
"from diffusers.configuration_utils import register_to_config\n"
"\n"
"class TinyModel(ModelMixin, ConfigMixin):\n"
" @register_to_config\n"
" def __init__(self, hidden_size=4):\n"
" super().__init__()\n"
" self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
"\n"
" def forward(self, x):\n"
" return self.linear(x)\n"
)
with open(os.path.join(model_dir, "modeling.py"), "w") as f:
f.write(TINY_MODEL_CODE)
config = {
"_class_name": "TinyModel",
"_diffusers_version": "0.0.0",
"auto_map": {"AutoModel": "modeling.TinyModel"},
"hidden_size": 4,
}
with open(os.path.join(model_dir, "config.json"), "w") as f:
json.dump(config, f)
torch.save(
{"linear.weight": torch.randn(4, 4), "linear.bias": torch.randn(4)},
os.path.join(model_dir, "diffusion_pytorch_model.bin"),
)
class DummyCustomBlockSimple(ModularPipelineBlocks):
@@ -341,6 +375,81 @@ class TestModularCustomBlocks:
loaded_pipe.update_components(custom_model=custom_model)
assert getattr(loaded_pipe, "custom_model", None) is not None
def test_automodel_type_hint_preserves_torch_dtype(self, tmp_path):
"""Regression test for #13271: torch_dtype was incorrectly removed when type_hint is AutoModel."""
from diffusers import AutoModel
model_dir = str(tmp_path / "model")
os.makedirs(model_dir)
_create_tiny_model_dir(model_dir)
class DtypeTestBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)]
@property
def inputs(self) -> List[InputParam]:
return [InputParam("prompt", type_hint=str, required=True)]
@property
def intermediate_inputs(self) -> List[InputParam]:
return []
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam("output", type_hint=str)]
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.output = "test"
self.set_block_state(state, block_state)
return components, state
block = DtypeTestBlock()
pipe = block.init_pipeline()
pipe.load_components(torch_dtype=torch.float16, trust_remote_code=True)
assert pipe.model.dtype == torch.float16
@require_torch_accelerator
def test_automodel_type_hint_preserves_device(self, tmp_path):
"""Test that ComponentSpec with AutoModel type_hint correctly passes device_map."""
from diffusers import AutoModel
model_dir = str(tmp_path / "model")
os.makedirs(model_dir)
_create_tiny_model_dir(model_dir)
class DeviceTestBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)]
@property
def inputs(self) -> List[InputParam]:
return [InputParam("prompt", type_hint=str, required=True)]
@property
def intermediate_inputs(self) -> List[InputParam]:
return []
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam("output", type_hint=str)]
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.output = "test"
self.set_block_state(state, block_state)
return components, state
block = DeviceTestBlock()
pipe = block.init_pipeline()
pipe.load_components(device_map=torch_device, trust_remote_code=True)
assert pipe.model.device.type == torch_device
def test_custom_block_loads_from_hub(self):
repo_id = "hf-internal-testing/tiny-modular-diffusers-block"
block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)