Compare commits


2 Commits

Author SHA1 Message Date
YiYi Xu
ebd06f9b11 [Modular] loader related (#13025)
* tag load_id from AutoModel

* style

* load_components by default only loads components that are not already loaded

* by default, skip loading components that do not have a repo id
2026-02-03 05:34:33 -10:00
songkey
b712042da1 [Flux2] Fix LoRA loading for Flux2 Klein by adaptively enumerating transformer blocks (#13030)
* Resolve Flux2 Klein 4B/9B LoRA loading errors

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-02-02 20:36:19 +05:30
7 changed files with 125 additions and 218 deletions

View File

@@ -2321,8 +2321,14 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
prefix = "diffusion_model."
original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
- num_double_layers = 8
- num_single_layers = 48
+ num_double_layers = 0
+ num_single_layers = 0
+ for key in original_state_dict.keys():
+     if key.startswith("single_blocks."):
+         num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1)
+     elif key.startswith("double_blocks."):
+         num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1)
lora_keys = ("lora_A", "lora_B")
attn_types = ("img_attn", "txt_attn")
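For context, a minimal standalone sketch of the adaptive enumeration introduced above, run against a toy state dict (the keys below are invented for illustration): the block counts are inferred from the highest block index present instead of being hardcoded to 8/48, which is what lets Klein checkpoints with a different block count load.

toy_state_dict = {
    "double_blocks.0.img_attn.qkv.lora_A.weight": None,
    "double_blocks.4.img_attn.qkv.lora_B.weight": None,
    "single_blocks.11.linear1.lora_A.weight": None,
}

num_double_layers = 0
num_single_layers = 0
for key in toy_state_dict:
    # Block index is the second dotted segment; +1 turns the max index into a count.
    if key.startswith("single_blocks."):
        num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1)
    elif key.startswith("double_blocks."):
        num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1)

print(num_double_layers, num_single_layers)  # -> 5 12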

View File

@@ -18,7 +18,7 @@ from typing import Optional, Union
from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
- from ..utils import logging
+ from ..utils import DIFFUSERS_LOAD_ID_FIELDS, logging
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
@@ -220,4 +220,11 @@ class AutoModel(ConfigMixin):
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
kwargs = {**load_config_kwargs, **kwargs}
- return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
+ model = model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
+ load_id_kwargs = {"pretrained_model_name_or_path": pretrained_model_or_path, **kwargs}
+ parts = [load_id_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
+ load_id = "|".join("null" if p is None else p for p in parts)
+ model._diffusers_load_id = load_id
+ return model
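To illustrate the tag attached above, here is a rough sketch (not the library call itself) of the load_id string that ends up on model._diffusers_load_id for a hypothetical AutoModel.from_pretrained("org/repo", subfolder="transformer") call; "org/repo" and "transformer" are made-up values.

DIFFUSERS_LOAD_ID_FIELDS = ["pretrained_model_name_or_path", "subfolder", "variant", "revision"]

load_id_kwargs = {"pretrained_model_name_or_path": "org/repo", "subfolder": "transformer"}
parts = [load_id_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
load_id = "|".join("null" if p is None else p for p in parts)
print(load_id)  # -> org/repo|transformer|null|null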

View File

@@ -2143,6 +2143,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
name
for name in self._component_specs.keys()
if self._component_specs[name].default_creation_method == "from_pretrained"
+ and self._component_specs[name].pretrained_model_name_or_path is not None
+ and getattr(self, name, None) is None
]
elif isinstance(names, str):
names = [names]
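In plain terms, the two added conditions mean that a parameterless load_components() now only touches components whose spec is meant to be loaded from_pretrained, actually carries a repo id, and is not already populated on the pipeline. A toy sketch of that filter with stand-in spec objects (the spec class and values here are invented; only the three conditions mirror the diff):

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeSpec:  # stand-in for ComponentSpec, for illustration only
    default_creation_method: str = "from_pretrained"
    pretrained_model_name_or_path: Optional[str] = None

specs = {
    "unet": FakeSpec(pretrained_model_name_or_path="org/repo"),
    "scheduler": FakeSpec(default_creation_method="from_config"),
    "vae": FakeSpec(),  # spec without a repo id -> skipped by default
}
loaded = {"unet": None, "scheduler": None, "vae": None}  # None = not loaded yet

names = [
    name
    for name, spec in specs.items()
    if spec.default_creation_method == "from_pretrained"
    and spec.pretrained_model_name_or_path is not None
    and loaded.get(name) is None
]
print(names)  # -> ['unet']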

View File

@@ -15,7 +15,7 @@
import inspect
import re
from collections import OrderedDict
- from dataclasses import dataclass, field, fields
+ from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Type, Union
import PIL.Image
@@ -23,7 +23,7 @@ import torch
from ..configuration_utils import ConfigMixin, FrozenDict
from ..loaders.single_file_utils import _is_single_file_path_or_url
- from ..utils import is_torch_available, logging
+ from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
if is_torch_available():
@@ -186,7 +186,7 @@ class ComponentSpec:
"""
Return the names of all loading-related fields (i.e. those whose field.metadata["loading"] is True).
"""
- return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
+ return DIFFUSERS_LOAD_ID_FIELDS.copy()
@property
def load_id(self) -> str:
@@ -198,7 +198,7 @@ class ComponentSpec:
return "null"
parts = [getattr(self, k) for k in self.loading_fields()]
parts = ["null" if p is None else p for p in parts]
return "|".join(p for p in parts if p)
return "|".join(parts)
@classmethod
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
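The change from dropping falsy parts to keeping a "null" placeholder makes the load_id positional: each of the four loading fields always occupies its slot, so the string can be split back apart. A small sketch of that round trip (the field values are invented; this is not the decode_load_id implementation itself):

DIFFUSERS_LOAD_ID_FIELDS = ["pretrained_model_name_or_path", "subfolder", "variant", "revision"]

values = {"pretrained_model_name_or_path": "org/repo", "subfolder": None, "variant": "fp16", "revision": None}
parts = ["null" if values[k] is None else values[k] for k in DIFFUSERS_LOAD_ID_FIELDS]
load_id = "|".join(parts)
print(load_id)  # -> org/repo|null|fp16|null

# Positional split recovers every field; "null" maps back to None.
decoded = {k: (None if v == "null" else v) for k, v in zip(DIFFUSERS_LOAD_ID_FIELDS, load_id.split("|"))}
print(decoded["variant"])  # -> fp16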

View File

@@ -23,6 +23,7 @@ from .constants import (
DEFAULT_HF_PARALLEL_LOADING_WORKERS,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_DYNAMIC_MODULE_NAME,
+ DIFFUSERS_LOAD_ID_FIELDS,
FLAX_WEIGHTS_NAME,
GGUF_FILE_EXTENSION,
HF_ENABLE_PARALLEL_LOADING,

View File

@@ -73,3 +73,11 @@ DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoint
ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
+ DIFFUSERS_LOAD_ID_FIELDS = [
+     "pretrained_model_name_or_path",
+     "subfolder",
+     "variant",
+     "revision",
+ ]

View File

@@ -13,87 +13,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import QwenImageTransformer2DModel
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return QwenImageTransformer2DModel
class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def output_shape(self) -> tuple[int, int]:
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
return (16, 16)
@property
def input_shape(self) -> tuple[int, int]:
def output_shape(self):
return (16, 16)
@property
def model_split_percents(self) -> list:
# We override the items here because the transformer under consideration is small.
return [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def uses_custom_attn_processor(self) -> bool:
# Skip setting testing with default: AttnProcessor
return True
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
return {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 4, # Must be divisible by 2 for Ulysses context parallel
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
def prepare_dummy_input(self, height=4, width=4):
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 8 # Must be divisible by 2 for context parallel tests
sequence_length = 7
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
@@ -108,12 +70,29 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
"img_shapes": img_shapes,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
def test_infers_text_seq_len_from_mask(self):
"""Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
# Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
@@ -125,56 +104,55 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
)
# Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
assert isinstance(rope_text_seq_len, int)
self.assertIsInstance(rope_text_seq_len, int)
# Verify per_sample_len is computed correctly (max valid position + 1 = 2)
assert isinstance(per_sample_len, torch.Tensor)
assert int(per_sample_len.max().item()) == 2
self.assertIsInstance(per_sample_len, torch.Tensor)
self.assertEqual(int(per_sample_len.max().item()), 2)
# Verify mask is normalized to bool dtype
assert normalized_mask.dtype == torch.bool
assert normalized_mask.sum().item() == 2 # Only 2 True values
self.assertTrue(normalized_mask.dtype == torch.bool)
self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values
# Verify rope_text_seq_len is at least the sequence length
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
# Test 2: Verify model runs successfully with inferred values
inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad():
output = model(**inputs)
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
# Test 3: Different mask pattern (padding at beginning)
encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding
encoder_hidden_states_mask2[:, 3:] = 1 # Last 5 tokens are valid (seq_len=8)
encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid
rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask2
)
# Max valid position is 7 (last token), so per_sample_len should be 8
assert int(per_sample_len2.max().item()) == 8
assert normalized_mask2.sum().item() == 5 # 5 True values
# Max valid position is 6 (last token), so per_sample_len should be 7
self.assertEqual(int(per_sample_len2.max().item()), 7)
self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values
# Test 4: No mask provided (None case)
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], None
)
assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1]
assert isinstance(rope_text_seq_len_none, int)
assert per_sample_len_none is None
assert normalized_mask_none is None
self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(rope_text_seq_len_none, int)
self.assertIsNone(per_sample_len_none)
self.assertIsNone(normalized_mask_none)
def test_non_contiguous_attention_mask(self):
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0, 0])"""
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
# Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
# Pattern: [True, False, True, False, True, False, False, False] (seq_len=8)
# Pattern: [True, False, True, False, True, False, False]
encoder_hidden_states_mask[:, 1] = 0
encoder_hidden_states_mask[:, 3] = 0
encoder_hidden_states_mask[:, 5:] = 0
@@ -182,22 +160,21 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask
)
assert int(per_sample_len.max().item()) == 5
assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1]
assert isinstance(inferred_rope_len, int)
assert normalized_mask.dtype == torch.bool
self.assertEqual(int(per_sample_len.max().item()), 5)
self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(inferred_rope_len, int)
self.assertTrue(normalized_mask.dtype == torch.bool)
inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad():
output = model(**inputs)
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
def test_txt_seq_lens_deprecation(self):
"""Test that passing txt_seq_lens raises a deprecation warning."""
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
# Prepare inputs with txt_seq_lens (deprecated parameter)
@@ -209,24 +186,18 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens
# Test that deprecation warning is raised
import warnings
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with self.assertWarns(FutureWarning) as warning_context:
with torch.no_grad():
output = model(**inputs_with_deprecated)
# Verify a FutureWarning was raised
future_warnings = [x for x in w if issubclass(x.category, FutureWarning)]
assert len(future_warnings) > 0, "Expected FutureWarning to be raised"
# Verify the warning message mentions the deprecation
warning_message = str(future_warnings[0].message)
assert "txt_seq_lens" in warning_message
assert "deprecated" in warning_message
# Verify the warning message mentions the deprecation
warning_message = str(warning_context.warning)
self.assertIn("txt_seq_lens", warning_message)
self.assertIn("deprecated", warning_message)
self.assertIn("encoder_hidden_states_mask", warning_message)
# Verify the model still works correctly despite the deprecation
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
def test_layered_model_with_mask(self):
"""Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
@@ -237,7 +208,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 4, # Must be divisible by 2 for Ulysses context parallel
"num_attention_heads": 3,
"joint_attention_dim": 16,
"axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16)
"use_layer3d_rope": True, # Enable layered RoPE
@@ -249,11 +220,11 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
# Verify the model uses QwenEmbedLayer3DRope
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope
assert isinstance(model.pos_embed, QwenEmbedLayer3DRope)
self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
# Test single generation with layered structure
batch_size = 1
text_seq_len = 8
text_seq_len = 7
img_h, img_w = 4, 4
layers = 4
@@ -291,104 +262,24 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
additional_t_cond=addition_t_cond,
)
assert output.sample.shape[1] == hidden_states.shape[1]
self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for QwenImage Transformer."""
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for QwenImage Transformer."""
def prepare_dummy_input(self, height, width):
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for QwenImage Transformer."""
class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for QwenImage Transformer."""
class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for QwenImage Transformer."""
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for QwenImage Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for LoRA hotswap tests."""
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 8
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for compilation tests."""
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 8 # Must be divisible by 2 for context parallel tests
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
def test_torch_compile_recompilation_and_graph_break(self):
super().test_torch_compile_recompilation_and_graph_break()
def test_torch_compile_with_and_without_mask(self):
"""Test that torch.compile works with both None mask and padding mask."""
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile(mode="default", fullgraph=True)
@@ -409,13 +300,13 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
):
output_no_mask_2 = model(**inputs_no_mask)
assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
# Test 2: Run with all-ones mask (should behave like None)
inputs_all_ones = inputs.copy()
# Keep the all-ones mask
assert inputs_all_ones["encoder_hidden_states_mask"].all().item()
self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
# First run to allow compilation
with torch.no_grad():
@@ -429,8 +320,8 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
):
output_all_ones_2 = model(**inputs_all_ones)
assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
# Test 3: Run with actual padding mask (has zeros)
inputs_with_padding = inputs.copy()
@@ -451,16 +342,8 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
):
output_with_padding_2 = model(**inputs_with_padding)
assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
# Verify that outputs are different (mask should affect results)
assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)
class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for QwenImage Transformer."""
class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for QwenImage Transformer."""
self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
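As a footnote on the mask semantics these tests exercise: per_sample_len is the highest valid (non-zero) mask position plus one, and rope_text_seq_len never drops below the tensor's sequence length (and equals it when no mask is passed). A rough standalone sketch of that rule, not the diffusers implementation itself:

import torch

def sketch_text_seq_len_from_mask(encoder_hidden_states, mask):
    # No mask: treat the whole sequence as valid.
    if mask is None:
        return encoder_hidden_states.shape[1], None, None
    normalized = mask.bool()
    # Highest valid index per sample, +1 turns it into a usable length.
    positions = torch.arange(normalized.shape[1], device=normalized.device)
    per_sample_len = (positions * normalized).amax(dim=1) + 1
    rope_len = max(int(per_sample_len.max().item()), encoder_hidden_states.shape[1])
    return rope_len, per_sample_len, normalized

x = torch.randn(1, 7, 16)
m = torch.tensor([[1, 0, 1, 0, 1, 0, 0]])
rope_len, per_sample_len, norm = sketch_text_seq_len_from_mask(x, m)
print(rope_len, int(per_sample_len.max()))  # -> 7 5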