Compare commits

..

1 Commits

Author SHA1 Message Date
sayakpaul
019a9deafb fix group offloading when using torchao 2026-03-17 10:40:03 +05:30
3 changed files with 210 additions and 257 deletions

View File

@@ -22,7 +22,7 @@ from typing import Set
import safetensors.torch
import torch
from ..utils import get_logger, is_accelerate_available
from ..utils import get_logger, is_accelerate_available, is_torchao_available
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook
@@ -35,6 +35,41 @@ if is_accelerate_available():
logger = get_logger(__name__) # pylint: disable=invalid-name
def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
"""Check if a tensor is a TorchAO quantized tensor subclass."""
if not is_torchao_available():
return False
from torchao.utils import TorchAOBaseTensor
return isinstance(tensor, TorchAOBaseTensor)
def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
"""Get names of all internal tensor data attributes from a TorchAO tensor."""
cls = type(tensor)
names = list(getattr(cls, "tensor_data_names", []))
for attr_name in getattr(cls, "optional_tensor_data_names", []):
if getattr(tensor, attr_name, None) is not None:
names.append(attr_name)
return names
def _update_torchao_tensor_in_place(param: torch.Tensor, source: torch.Tensor) -> None:
"""Update internal tensor data of a TorchAO parameter in-place from source.
Must operate on the parameter/buffer object directly (not ``param.data``) because ``_make_wrapper_subclass``
returns a fresh wrapper from ``.data`` each time, so attribute mutations on ``.data`` are lost.
"""
for attr_name in _get_torchao_inner_tensor_names(source):
setattr(param, attr_name, getattr(source, attr_name))
def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
"""Record stream for all internal tensors of a TorchAO parameter."""
for attr_name in _get_torchao_inner_tensor_names(param):
getattr(param, attr_name).record_stream(stream)
# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
@@ -157,9 +192,16 @@ class ModuleGroup:
pinned_dict = None
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if _is_torchao_tensor(tensor):
_update_torchao_tensor_in_place(tensor, moved)
else:
tensor.data = moved
if self.record_stream:
tensor.data.record_stream(default_stream)
if _is_torchao_tensor(tensor):
_record_stream_torchao_tensor(tensor, default_stream)
else:
tensor.data.record_stream(default_stream)
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
for group_module in self.modules:
@@ -245,18 +287,35 @@ class ModuleGroup:
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
if _is_torchao_tensor(param):
_update_torchao_tensor_in_place(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for param in self.parameters:
param.data = self.cpu_param_dict[param]
if _is_torchao_tensor(param):
_update_torchao_tensor_in_place(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]
if _is_torchao_tensor(buffer):
_update_torchao_tensor_in_place(buffer, self.cpu_param_dict[buffer])
else:
buffer.data = self.cpu_param_dict[buffer]
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=False)
if _is_torchao_tensor(param):
moved = param.data.to(self.offload_device, non_blocking=False)
_update_torchao_tensor_in_place(param, moved)
else:
param.data = param.data.to(self.offload_device, non_blocking=False)
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
if _is_torchao_tensor(buffer):
moved = buffer.data.to(self.offload_device, non_blocking=False)
_update_torchao_tensor_in_place(buffer, moved)
else:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
@torch.compiler.disable()
def onload_(self):

View File

@@ -229,7 +229,6 @@ class AttentionBackendName(str, Enum):
FLASH_HUB = "flash_hub"
FLASH_VARLEN = "flash_varlen"
FLASH_VARLEN_HUB = "flash_varlen_hub"
FLASH_4_HUB = "flash_4_hub"
_FLASH_3 = "_flash_3"
_FLASH_VARLEN_3 = "_flash_varlen_3"
_FLASH_3_HUB = "_flash_3_hub"
@@ -359,11 +358,6 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
function_attr="sageattn",
version=1,
),
AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
repo_id="kernels-staging/flash-attn4",
function_attr="flash_attn_func",
version=0,
),
}
@@ -527,7 +521,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
AttentionBackendName._FLASH_3_HUB,
AttentionBackendName._FLASH_3_VARLEN_HUB,
AttentionBackendName.SAGE_HUB,
AttentionBackendName.FLASH_4_HUB,
]:
if not is_kernels_available():
raise RuntimeError(
@@ -2683,37 +2676,6 @@ def _flash_attention_3_varlen_hub(
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_4_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
)
def _flash_attention_4_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
scale: float | None = None,
is_causal: bool = False,
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 4.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
)
if isinstance(out, tuple):
return (out[0], out[1]) if return_lse else out[0]
return out
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_VARLEN_3,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],

View File

@@ -1,3 +1,4 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,84 +13,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
import unittest
import torch
from diffusers import QwenImageTransformer2DModel
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return QwenImageTransformer2DModel
class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def output_shape(self) -> tuple[int, int]:
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
return (16, 16)
@property
def input_shape(self) -> tuple[int, int]:
def output_shape(self):
return (16, 16)
@property
def model_split_percents(self) -> list:
return [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
return {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 4,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
def prepare_dummy_input(self, height=4, width=4):
batch_size = 1
num_latent_channels = embedding_dim = 16
height = width = 4
sequence_length = 8
sequence_length = 7
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
@@ -104,57 +70,89 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
"img_shapes": img_shapes,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
def test_infers_text_seq_len_from_mask(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
"""Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
# Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask[:, 2:] = 0
encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid
rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask
)
assert isinstance(rope_text_seq_len, int)
assert isinstance(per_sample_len, torch.Tensor)
assert int(per_sample_len.max().item()) == 2
assert normalized_mask.dtype == torch.bool
assert normalized_mask.sum().item() == 2
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
# Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
self.assertIsInstance(rope_text_seq_len, int)
# Verify per_sample_len is computed correctly (max valid position + 1 = 2)
self.assertIsInstance(per_sample_len, torch.Tensor)
self.assertEqual(int(per_sample_len.max().item()), 2)
# Verify mask is normalized to bool dtype
self.assertTrue(normalized_mask.dtype == torch.bool)
self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values
# Verify rope_text_seq_len is at least the sequence length
self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
# Test 2: Verify model runs successfully with inferred values
inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad():
output = model(**inputs)
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
# Test 3: Different mask pattern (padding at beginning)
encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask2[:, :3] = 0
encoder_hidden_states_mask2[:, 3:] = 1
encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding
encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid
rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask2
)
assert int(per_sample_len2.max().item()) == 8
assert normalized_mask2.sum().item() == 5
# Max valid position is 6 (last token), so per_sample_len should be 7
self.assertEqual(int(per_sample_len2.max().item()), 7)
self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values
# Test 4: No mask provided (None case)
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], None
)
assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1]
assert isinstance(rope_text_seq_len_none, int)
assert per_sample_len_none is None
assert normalized_mask_none is None
self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(rope_text_seq_len_none, int)
self.assertIsNone(per_sample_len_none)
self.assertIsNone(normalized_mask_none)
def test_non_contiguous_attention_mask(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
# Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
# Pattern: [True, False, True, False, True, False, False]
encoder_hidden_states_mask[:, 1] = 0
encoder_hidden_states_mask[:, 3] = 0
encoder_hidden_states_mask[:, 5:] = 0
@@ -162,85 +160,95 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask
)
assert int(per_sample_len.max().item()) == 5
assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1]
assert isinstance(inferred_rope_len, int)
assert normalized_mask.dtype == torch.bool
self.assertEqual(int(per_sample_len.max().item()), 5)
self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(inferred_rope_len, int)
self.assertTrue(normalized_mask.dtype == torch.bool)
inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad():
output = model(**inputs)
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
def test_txt_seq_lens_deprecation(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
"""Test that passing txt_seq_lens raises a deprecation warning."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
# Prepare inputs with txt_seq_lens (deprecated parameter)
txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]
# Remove encoder_hidden_states_mask to use the deprecated path
inputs_with_deprecated = inputs.copy()
inputs_with_deprecated.pop("encoder_hidden_states_mask")
inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
# Test that deprecation warning is raised
with self.assertWarns(FutureWarning) as warning_context:
with torch.no_grad():
output = model(**inputs_with_deprecated)
future_warnings = [x for x in w if issubclass(x.category, FutureWarning)]
assert len(future_warnings) > 0, "Expected FutureWarning to be raised"
# Verify the warning message mentions the deprecation
warning_message = str(warning_context.warning)
self.assertIn("txt_seq_lens", warning_message)
self.assertIn("deprecated", warning_message)
self.assertIn("encoder_hidden_states_mask", warning_message)
warning_message = str(future_warnings[0].message)
assert "txt_seq_lens" in warning_message
assert "deprecated" in warning_message
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
# Verify the model still works correctly despite the deprecation
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
def test_layered_model_with_mask(self):
"""Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
# Create layered model config
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 4,
"num_attention_heads": 3,
"joint_attention_dim": 16,
"axes_dims_rope": (8, 4, 4),
"use_layer3d_rope": True,
"use_additional_t_cond": True,
"axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16)
"use_layer3d_rope": True, # Enable layered RoPE
"use_additional_t_cond": True, # Enable additional time conditioning
}
model = self.model_class(**init_dict).to(torch_device)
# Verify the model uses QwenEmbedLayer3DRope
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope
assert isinstance(model.pos_embed, QwenEmbedLayer3DRope)
self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
# Test single generation with layered structure
batch_size = 1
text_seq_len = 8
text_seq_len = 7
img_h, img_w = 4, 4
layers = 4
# For layered model: (layers + 1) because we have N layers + 1 combined image
hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)
# Create mask with some padding
encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
encoder_hidden_states_mask[0, 5:] = 0
encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens
timestep = torch.tensor([1.0]).to(torch_device)
# additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding)
addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)
# Layer structure: 4 layers + 1 condition image
img_shapes = [
[
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w), # layer 0
(1, img_h, img_w), # layer 1
(1, img_h, img_w), # layer 2
(1, img_h, img_w), # layer 3
(1, img_h, img_w), # condition image (last one gets special treatment)
]
]
@@ -254,113 +262,37 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
additional_t_cond=addition_t_cond,
)
assert output.sample.shape[1] == hidden_states.shape[1]
self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for QwenImage Transformer."""
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for QwenImage Transformer."""
def prepare_dummy_input(self, height, width):
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for QwenImage Transformer."""
class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for QwenImage Transformer."""
class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for QwenImage Transformer."""
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for QwenImage Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 8
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for QwenImage Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 8
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
def test_torch_compile_recompilation_and_graph_break(self):
super().test_torch_compile_recompilation_and_graph_break()
def test_torch_compile_with_and_without_mask(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
"""Test that torch.compile works with both None mask and padding mask."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile(mode="default", fullgraph=True)
# Test 1: Run with None mask (no padding, all tokens are valid)
inputs_no_mask = inputs.copy()
inputs_no_mask["encoder_hidden_states_mask"] = None
# First run to allow compilation
with torch.no_grad():
output_no_mask = model(**inputs_no_mask)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
@@ -368,15 +300,19 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
):
output_no_mask_2 = model(**inputs_no_mask)
assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
# Test 2: Run with all-ones mask (should behave like None)
inputs_all_ones = inputs.copy()
assert inputs_all_ones["encoder_hidden_states_mask"].all().item()
# Keep the all-ones mask
self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
# First run to allow compilation
with torch.no_grad():
output_all_ones = model(**inputs_all_ones)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
@@ -384,18 +320,21 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
):
output_all_ones_2 = model(**inputs_all_ones)
assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
# Test 3: Run with actual padding mask (has zeros)
inputs_with_padding = inputs.copy()
mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
mask_with_padding[:, 4:] = 0
mask_with_padding[:, 4:] = 0 # Last 3 tokens are padding
inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding
# First run to allow compilation
with torch.no_grad():
output_with_padding = model(**inputs_with_padding)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
@@ -403,15 +342,8 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
):
output_with_padding_2 = model(**inputs_with_padding)
assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)
class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for QwenImage Transformer."""
class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for QwenImage Transformer."""
# Verify that outputs are different (mask should affect results)
self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))