mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-20 15:38:11 +08:00
Compare commits
8 Commits
type-hint-
...
make-tiny-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1d75ab7e35 | ||
|
|
b086b6da0a | ||
|
|
28c7516229 | ||
|
|
53b9b56059 | ||
|
|
65579667e9 | ||
|
|
cda1c36eeb | ||
|
|
f634485333 | ||
|
|
7820980959 |
@@ -1,3 +1,4 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -12,84 +13,49 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
ContextParallelTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return QwenImageTransformer2DModel
|
||||
class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = QwenImageTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (16, 16)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
def output_shape(self):
|
||||
return (16, 16)
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||
return {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 4,
|
||||
"joint_attention_dim": 16,
|
||||
"guidance_embeds": False,
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
def prepare_dummy_input(self, height=4, width=4):
|
||||
batch_size = 1
|
||||
num_latent_channels = embedding_dim = 16
|
||||
height = width = 4
|
||||
sequence_length = 8
|
||||
sequence_length = 7
|
||||
vae_scale_factor = 4
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
orig_height = height * 2 * vae_scale_factor
|
||||
@@ -104,57 +70,89 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 3,
|
||||
"joint_attention_dim": 16,
|
||||
"guidance_embeds": False,
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
}
|
||||
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"QwenImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_infers_text_seq_len_from_mask(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
"""Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
|
||||
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
|
||||
encoder_hidden_states_mask[:, 2:] = 0
|
||||
encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid
|
||||
|
||||
rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], encoder_hidden_states_mask
|
||||
)
|
||||
|
||||
assert isinstance(rope_text_seq_len, int)
|
||||
assert isinstance(per_sample_len, torch.Tensor)
|
||||
assert int(per_sample_len.max().item()) == 2
|
||||
assert normalized_mask.dtype == torch.bool
|
||||
assert normalized_mask.sum().item() == 2
|
||||
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
|
||||
# Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
|
||||
self.assertIsInstance(rope_text_seq_len, int)
|
||||
|
||||
# Verify per_sample_len is computed correctly (max valid position + 1 = 2)
|
||||
self.assertIsInstance(per_sample_len, torch.Tensor)
|
||||
self.assertEqual(int(per_sample_len.max().item()), 2)
|
||||
|
||||
# Verify mask is normalized to bool dtype
|
||||
self.assertTrue(normalized_mask.dtype == torch.bool)
|
||||
self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values
|
||||
|
||||
# Verify rope_text_seq_len is at least the sequence length
|
||||
self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
|
||||
|
||||
# Test 2: Verify model runs successfully with inferred values
|
||||
inputs["encoder_hidden_states_mask"] = normalized_mask
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
# Test 3: Different mask pattern (padding at beginning)
|
||||
encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
|
||||
encoder_hidden_states_mask2[:, :3] = 0
|
||||
encoder_hidden_states_mask2[:, 3:] = 1
|
||||
encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding
|
||||
encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid
|
||||
|
||||
rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], encoder_hidden_states_mask2
|
||||
)
|
||||
|
||||
assert int(per_sample_len2.max().item()) == 8
|
||||
assert normalized_mask2.sum().item() == 5
|
||||
# Max valid position is 6 (last token), so per_sample_len should be 7
|
||||
self.assertEqual(int(per_sample_len2.max().item()), 7)
|
||||
self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values
|
||||
|
||||
# Test 4: No mask provided (None case)
|
||||
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], None
|
||||
)
|
||||
assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1]
|
||||
assert isinstance(rope_text_seq_len_none, int)
|
||||
assert per_sample_len_none is None
|
||||
assert normalized_mask_none is None
|
||||
self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
|
||||
self.assertIsInstance(rope_text_seq_len_none, int)
|
||||
self.assertIsNone(per_sample_len_none)
|
||||
self.assertIsNone(normalized_mask_none)
|
||||
|
||||
def test_non_contiguous_attention_mask(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
|
||||
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
|
||||
# Pattern: [True, False, True, False, True, False, False]
|
||||
encoder_hidden_states_mask[:, 1] = 0
|
||||
encoder_hidden_states_mask[:, 3] = 0
|
||||
encoder_hidden_states_mask[:, 5:] = 0
|
||||
@@ -162,85 +160,95 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], encoder_hidden_states_mask
|
||||
)
|
||||
assert int(per_sample_len.max().item()) == 5
|
||||
assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1]
|
||||
assert isinstance(inferred_rope_len, int)
|
||||
assert normalized_mask.dtype == torch.bool
|
||||
self.assertEqual(int(per_sample_len.max().item()), 5)
|
||||
self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
|
||||
self.assertIsInstance(inferred_rope_len, int)
|
||||
self.assertTrue(normalized_mask.dtype == torch.bool)
|
||||
|
||||
inputs["encoder_hidden_states_mask"] = normalized_mask
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
def test_txt_seq_lens_deprecation(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
"""Test that passing txt_seq_lens raises a deprecation warning."""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Prepare inputs with txt_seq_lens (deprecated parameter)
|
||||
txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]
|
||||
|
||||
# Remove encoder_hidden_states_mask to use the deprecated path
|
||||
inputs_with_deprecated = inputs.copy()
|
||||
inputs_with_deprecated.pop("encoder_hidden_states_mask")
|
||||
inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
# Test that deprecation warning is raised
|
||||
with self.assertWarns(FutureWarning) as warning_context:
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_with_deprecated)
|
||||
|
||||
future_warnings = [x for x in w if issubclass(x.category, FutureWarning)]
|
||||
assert len(future_warnings) > 0, "Expected FutureWarning to be raised"
|
||||
# Verify the warning message mentions the deprecation
|
||||
warning_message = str(warning_context.warning)
|
||||
self.assertIn("txt_seq_lens", warning_message)
|
||||
self.assertIn("deprecated", warning_message)
|
||||
self.assertIn("encoder_hidden_states_mask", warning_message)
|
||||
|
||||
warning_message = str(future_warnings[0].message)
|
||||
assert "txt_seq_lens" in warning_message
|
||||
assert "deprecated" in warning_message
|
||||
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
# Verify the model still works correctly despite the deprecation
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
def test_layered_model_with_mask(self):
|
||||
"""Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
|
||||
# Create layered model config
|
||||
init_dict = {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 4,
|
||||
"num_attention_heads": 3,
|
||||
"joint_attention_dim": 16,
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
"use_layer3d_rope": True,
|
||||
"use_additional_t_cond": True,
|
||||
"axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16)
|
||||
"use_layer3d_rope": True, # Enable layered RoPE
|
||||
"use_additional_t_cond": True, # Enable additional time conditioning
|
||||
}
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Verify the model uses QwenEmbedLayer3DRope
|
||||
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope
|
||||
|
||||
assert isinstance(model.pos_embed, QwenEmbedLayer3DRope)
|
||||
self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
|
||||
|
||||
# Test single generation with layered structure
|
||||
batch_size = 1
|
||||
text_seq_len = 8
|
||||
text_seq_len = 7
|
||||
img_h, img_w = 4, 4
|
||||
layers = 4
|
||||
|
||||
# For layered model: (layers + 1) because we have N layers + 1 combined image
|
||||
hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
|
||||
encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)
|
||||
|
||||
# Create mask with some padding
|
||||
encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
|
||||
encoder_hidden_states_mask[0, 5:] = 0
|
||||
encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device)
|
||||
|
||||
# additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding)
|
||||
addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)
|
||||
|
||||
# Layer structure: 4 layers + 1 condition image
|
||||
img_shapes = [
|
||||
[
|
||||
(1, img_h, img_w),
|
||||
(1, img_h, img_w),
|
||||
(1, img_h, img_w),
|
||||
(1, img_h, img_w),
|
||||
(1, img_h, img_w),
|
||||
(1, img_h, img_w), # layer 0
|
||||
(1, img_h, img_w), # layer 1
|
||||
(1, img_h, img_w), # layer 2
|
||||
(1, img_h, img_w), # layer 3
|
||||
(1, img_h, img_w), # condition image (last one gets special treatment)
|
||||
]
|
||||
]
|
||||
|
||||
@@ -254,113 +262,37 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
additional_t_cond=addition_t_cond,
|
||||
)
|
||||
|
||||
assert output.sample.shape[1] == hidden_states.shape[1]
|
||||
self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
|
||||
|
||||
|
||||
class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for QwenImage Transformer."""
|
||||
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = QwenImageTransformer2DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for QwenImage Transformer."""
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"QwenImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
|
||||
"""Context Parallel inference tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for QwenImage Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_latent_channels = embedding_dim = 16
|
||||
sequence_length = 8
|
||||
vae_scale_factor = 4
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
orig_height = height * 2 * vae_scale_factor
|
||||
orig_width = width * 2 * vae_scale_factor
|
||||
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
|
||||
class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
"""Torch compile tests for QwenImage Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_latent_channels = embedding_dim = 16
|
||||
sequence_length = 8
|
||||
vae_scale_factor = 4
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
orig_height = height * 2 * vae_scale_factor
|
||||
orig_width = width * 2 * vae_scale_factor
|
||||
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
super().test_torch_compile_recompilation_and_graph_break()
|
||||
|
||||
def test_torch_compile_with_and_without_mask(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
"""Test that torch.compile works with both None mask and padding mask."""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model.compile(mode="default", fullgraph=True)
|
||||
|
||||
# Test 1: Run with None mask (no padding, all tokens are valid)
|
||||
inputs_no_mask = inputs.copy()
|
||||
inputs_no_mask["encoder_hidden_states_mask"] = None
|
||||
|
||||
# First run to allow compilation
|
||||
with torch.no_grad():
|
||||
output_no_mask = model(**inputs_no_mask)
|
||||
|
||||
# Second run to verify no recompilation
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=True),
|
||||
@@ -368,15 +300,19 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
|
||||
):
|
||||
output_no_mask_2 = model(**inputs_no_mask)
|
||||
|
||||
assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
# Test 2: Run with all-ones mask (should behave like None)
|
||||
inputs_all_ones = inputs.copy()
|
||||
assert inputs_all_ones["encoder_hidden_states_mask"].all().item()
|
||||
# Keep the all-ones mask
|
||||
self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
|
||||
|
||||
# First run to allow compilation
|
||||
with torch.no_grad():
|
||||
output_all_ones = model(**inputs_all_ones)
|
||||
|
||||
# Second run to verify no recompilation
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=True),
|
||||
@@ -384,18 +320,21 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
|
||||
):
|
||||
output_all_ones_2 = model(**inputs_all_ones)
|
||||
|
||||
assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
# Test 3: Run with actual padding mask (has zeros)
|
||||
inputs_with_padding = inputs.copy()
|
||||
mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
|
||||
mask_with_padding[:, 4:] = 0
|
||||
mask_with_padding[:, 4:] = 0 # Last 3 tokens are padding
|
||||
|
||||
inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding
|
||||
|
||||
# First run to allow compilation
|
||||
with torch.no_grad():
|
||||
output_with_padding = model(**inputs_with_padding)
|
||||
|
||||
# Second run to verify no recompilation
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=True),
|
||||
@@ -403,15 +342,8 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
|
||||
):
|
||||
output_with_padding_2 = model(**inputs_with_padding)
|
||||
|
||||
assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)
|
||||
|
||||
|
||||
class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for QwenImage Transformer."""
|
||||
# Verify that outputs are different (mask should affect results)
|
||||
self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
|
||||
|
||||
@@ -31,41 +31,7 @@ from diffusers.modular_pipelines import (
|
||||
WanModularPipeline,
|
||||
)
|
||||
|
||||
from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device
|
||||
|
||||
|
||||
def _create_tiny_model_dir(model_dir):
|
||||
TINY_MODEL_CODE = (
|
||||
"import torch\n"
|
||||
"from diffusers import ModelMixin, ConfigMixin\n"
|
||||
"from diffusers.configuration_utils import register_to_config\n"
|
||||
"\n"
|
||||
"class TinyModel(ModelMixin, ConfigMixin):\n"
|
||||
" @register_to_config\n"
|
||||
" def __init__(self, hidden_size=4):\n"
|
||||
" super().__init__()\n"
|
||||
" self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
|
||||
"\n"
|
||||
" def forward(self, x):\n"
|
||||
" return self.linear(x)\n"
|
||||
)
|
||||
|
||||
with open(os.path.join(model_dir, "modeling.py"), "w") as f:
|
||||
f.write(TINY_MODEL_CODE)
|
||||
|
||||
config = {
|
||||
"_class_name": "TinyModel",
|
||||
"_diffusers_version": "0.0.0",
|
||||
"auto_map": {"AutoModel": "modeling.TinyModel"},
|
||||
"hidden_size": 4,
|
||||
}
|
||||
with open(os.path.join(model_dir, "config.json"), "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
torch.save(
|
||||
{"linear.weight": torch.randn(4, 4), "linear.bias": torch.randn(4)},
|
||||
os.path.join(model_dir, "diffusion_pytorch_model.bin"),
|
||||
)
|
||||
from ..testing_utils import nightly, require_torch, slow
|
||||
|
||||
|
||||
class DummyCustomBlockSimple(ModularPipelineBlocks):
|
||||
@@ -375,81 +341,6 @@ class TestModularCustomBlocks:
|
||||
loaded_pipe.update_components(custom_model=custom_model)
|
||||
assert getattr(loaded_pipe, "custom_model", None) is not None
|
||||
|
||||
def test_automodel_type_hint_preserves_torch_dtype(self, tmp_path):
|
||||
"""Regression test for #13271: torch_dtype was incorrectly removed when type_hint is AutoModel."""
|
||||
from diffusers import AutoModel
|
||||
|
||||
model_dir = str(tmp_path / "model")
|
||||
os.makedirs(model_dir)
|
||||
_create_tiny_model_dir(model_dir)
|
||||
|
||||
class DtypeTestBlock(ModularPipelineBlocks):
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [InputParam("prompt", type_hint=str, required=True)]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [OutputParam("output", type_hint=str)]
|
||||
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.output = "test"
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
block = DtypeTestBlock()
|
||||
pipe = block.init_pipeline()
|
||||
pipe.load_components(torch_dtype=torch.float16, trust_remote_code=True)
|
||||
|
||||
assert pipe.model.dtype == torch.float16
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_automodel_type_hint_preserves_device(self, tmp_path):
|
||||
"""Test that ComponentSpec with AutoModel type_hint correctly passes device_map."""
|
||||
from diffusers import AutoModel
|
||||
|
||||
model_dir = str(tmp_path / "model")
|
||||
os.makedirs(model_dir)
|
||||
_create_tiny_model_dir(model_dir)
|
||||
|
||||
class DeviceTestBlock(ModularPipelineBlocks):
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [InputParam("prompt", type_hint=str, required=True)]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [OutputParam("output", type_hint=str)]
|
||||
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.output = "test"
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
block = DeviceTestBlock()
|
||||
pipe = block.init_pipeline()
|
||||
pipe.load_components(device_map=torch_device, trust_remote_code=True)
|
||||
|
||||
assert pipe.model.device.type == torch_device
|
||||
|
||||
def test_custom_block_loads_from_hub(self):
|
||||
repo_id = "hf-internal-testing/tiny-modular-diffusers-block"
|
||||
block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
|
||||
|
||||
210
utils/make_tiny_model.py
Normal file
210
utils/make_tiny_model.py
Normal file
@@ -0,0 +1,210 @@
|
||||
# /// script
|
||||
# requires-python = ">=3.10"
|
||||
# dependencies = [
|
||||
# "diffusers",
|
||||
# "torch",
|
||||
# "huggingface_hub",
|
||||
# "accelerate",
|
||||
# "transformers",
|
||||
# "sentencepiece",
|
||||
# "protobuf",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Utility script to create tiny versions of diffusers models by reducing layer counts.
|
||||
|
||||
Can be run locally or submitted as an HF Job via `--launch`.
|
||||
|
||||
Usage:
|
||||
# Run locally
|
||||
python make_tiny_model.py --model_repo_id <model_repo_id> --output_repo_id <output_repo_id> [--subfolder transformer] [--num_layers 2]
|
||||
|
||||
# Push to Hub
|
||||
python make_tiny_model.py --model_repo_id <model_repo_id> --output_repo_id <output_repo_id> --push_to_hub --token $HF_TOKEN
|
||||
|
||||
# Submit as an HF Job
|
||||
python make_tiny_model.py --model_repo_id <model_repo_id> --output_repo_id <output_repo_id> --launch [--flavor cpu-basic]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
|
||||
LAYER_PARAM_PATTERN = re.compile(r"^(num_.*layers?|n_layers|n_refiner_layers)$")
|
||||
|
||||
DIM_PARAM_PATTERNS = {
|
||||
re.compile(r"^num_attention_heads$"): 2,
|
||||
re.compile(r"^num_.*attention_heads$"): 2,
|
||||
re.compile(r"^num_key_value_heads$"): 2,
|
||||
re.compile(r"^num_kv_heads$"): 1,
|
||||
re.compile(r"^n_heads$"): 2,
|
||||
re.compile(r"^n_kv_heads$"): 2,
|
||||
re.compile(r"^attention_head_dim$"): 8,
|
||||
re.compile(r"^.*attention_head_dim$"): 4,
|
||||
re.compile(r"^cross_attention_dim.*$"): 8,
|
||||
re.compile(r"^joint_attention_dim$"): 32,
|
||||
re.compile(r"^pooled_projection_dim$"): 32,
|
||||
re.compile(r"^caption_projection_dim$"): 32,
|
||||
re.compile(r"^caption_channels$"): 8,
|
||||
re.compile(r"^cap_feat_dim$"): 16,
|
||||
re.compile(r"^hidden_size$"): 16,
|
||||
re.compile(r"^dim$"): 16,
|
||||
re.compile(r"^.*embed_dim$"): 16,
|
||||
re.compile(r"^.*embed_.*dim$"): 16,
|
||||
re.compile(r"^text_dim$"): 16,
|
||||
re.compile(r"^time_embed_dim$"): 4,
|
||||
re.compile(r"^ffn_dim$"): 32,
|
||||
re.compile(r"^intermediate_size$"): 32,
|
||||
re.compile(r"^sample_size$"): 32,
|
||||
}
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Create a tiny version of a diffusers model.")
|
||||
parser.add_argument("--model_repo_id", type=str, required=True, help="HuggingFace repo ID of the source model.")
|
||||
parser.add_argument(
|
||||
"--output_repo_id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="HuggingFace repo ID or local path to save the tiny model to.",
|
||||
)
|
||||
parser.add_argument("--subfolder", type=str, default=None, help="Subfolder within the model repo.")
|
||||
parser.add_argument("--num_layers", type=int, default=2, help="Number of layers to use for the tiny model.")
|
||||
parser.add_argument(
|
||||
"--shrink_dims",
|
||||
action="store_true",
|
||||
help="Also reduce dimension parameters (attention heads, hidden size, embedding dims, etc.).",
|
||||
)
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Push the tiny model to the HuggingFace Hub.")
|
||||
parser.add_argument(
|
||||
"--token", type=str, default=None, help="HuggingFace token. Defaults to $HF_TOKEN env var if not provided."
|
||||
)
|
||||
|
||||
launch_group = parser.add_argument_group("HF Jobs launch options")
|
||||
launch_group.add_argument("--launch", action="store_true", help="Submit as an HF Job instead of running locally.")
|
||||
launch_group.add_argument("--flavor", type=str, default="cpu-basic", help="HF Jobs hardware flavor.")
|
||||
launch_group.add_argument("--timeout", type=str, default="30m", help="HF Jobs timeout.")
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.token is None:
|
||||
args.token = os.environ.get("HF_TOKEN")
|
||||
return args
|
||||
|
||||
|
||||
def launch_job(args):
|
||||
from huggingface_hub import run_uv_job
|
||||
|
||||
script_args = [
|
||||
"--model_repo_id",
|
||||
args.model_repo_id,
|
||||
"--output_repo_id",
|
||||
args.output_repo_id,
|
||||
"--num_layers",
|
||||
str(args.num_layers),
|
||||
]
|
||||
if args.subfolder:
|
||||
script_args.extend(["--subfolder", args.subfolder])
|
||||
if args.shrink_dims:
|
||||
script_args.append("--shrink_dims")
|
||||
if args.push_to_hub:
|
||||
script_args.append("--push_to_hub")
|
||||
|
||||
job = run_uv_job(
|
||||
__file__,
|
||||
script_args=script_args,
|
||||
flavor=args.flavor,
|
||||
timeout=args.timeout,
|
||||
secrets={"HF_TOKEN": args.token} if args.token else {},
|
||||
)
|
||||
print(f"Job submitted: {job.url}")
|
||||
print(f"Job ID: {job.id}")
|
||||
return job
|
||||
|
||||
|
||||
def make_tiny_model(
|
||||
model_repo_id, output_repo_id, subfolder=None, num_layers=2, shrink_dims=False, push_to_hub=False, token=None
|
||||
):
|
||||
from diffusers import AutoModel
|
||||
|
||||
config_kwargs = {}
|
||||
if token:
|
||||
config_kwargs["token"] = token
|
||||
|
||||
config = AutoModel.load_config(model_repo_id, subfolder=subfolder, **config_kwargs)
|
||||
|
||||
modified_keys = {}
|
||||
for key, value in config.items():
|
||||
if LAYER_PARAM_PATTERN.match(key) and isinstance(value, int) and value > num_layers:
|
||||
modified_keys[key] = (value, num_layers)
|
||||
config[key] = num_layers
|
||||
|
||||
if shrink_dims:
|
||||
for key, value in config.items():
|
||||
if not isinstance(value, int) or key.startswith("_"):
|
||||
continue
|
||||
for pattern, tiny_value in DIM_PARAM_PATTERNS.items():
|
||||
if pattern.match(key) and value > tiny_value:
|
||||
modified_keys[key] = (value, tiny_value)
|
||||
config[key] = tiny_value
|
||||
break
|
||||
|
||||
if not modified_keys:
|
||||
print("WARNING: No config parameters were modified.")
|
||||
print(f"Config keys: {[k for k in config if not k.startswith('_')]}")
|
||||
return
|
||||
|
||||
print("Modified config parameters:")
|
||||
for key, (old, new) in modified_keys.items():
|
||||
print(f" {key}: {old} -> {new}")
|
||||
|
||||
model = AutoModel.from_config(config)
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
print(f"Tiny model created with {total_params:,} parameters.")
|
||||
|
||||
save_kwargs = {}
|
||||
if token:
|
||||
save_kwargs["token"] = token
|
||||
if push_to_hub:
|
||||
save_kwargs["repo_id"] = output_repo_id
|
||||
model.save_pretrained(output_repo_id, push_to_hub=push_to_hub, **save_kwargs)
|
||||
if push_to_hub:
|
||||
print(f"Model pushed to https://huggingface.co/{output_repo_id}")
|
||||
else:
|
||||
print(f"Model saved to {output_repo_id}")
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
if args.launch:
|
||||
launch_job(args)
|
||||
else:
|
||||
make_tiny_model(
|
||||
model_repo_id=args.model_repo_id,
|
||||
output_repo_id=args.output_repo_id,
|
||||
subfolder=args.subfolder,
|
||||
num_layers=args.num_layers,
|
||||
shrink_dims=args.shrink_dims,
|
||||
push_to_hub=args.push_to_hub,
|
||||
token=args.token,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user