Compare commits


8 Commits

Author  SHA1        Message                                   Date
DN6     1d75ab7e35  update                                    2026-03-17 23:23:01 +05:30
DN6     b086b6da0a  update                                    2026-03-17 15:40:27 +05:30
DN6     28c7516229  update                                    2026-03-17 15:38:03 +05:30
DN6     53b9b56059  update                                    2026-03-17 15:33:17 +05:30
DN6     65579667e9  update                                    2026-03-17 15:31:30 +05:30
DN6     cda1c36eeb  update                                    2026-03-17 13:41:49 +05:30
DN6     f634485333  Merge branch 'main' into make-tiny-model  2026-03-17 13:37:02 +05:30
DN6     7820980959  update                                    2026-03-17 13:36:49 +05:30
3 changed files with 354 additions and 321 deletions


@@ -1,3 +1,4 @@
+# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,84 +13,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import warnings
+import unittest

import torch

from diffusers import QwenImageTransformer2DModel
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
-from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
-from ..testing_utils import (
-    AttentionTesterMixin,
-    BaseModelTesterConfig,
-    BitsAndBytesTesterMixin,
-    ContextParallelTesterMixin,
-    LoraHotSwappingForModelTesterMixin,
-    LoraTesterMixin,
-    MemoryTesterMixin,
-    ModelTesterMixin,
-    TorchAoTesterMixin,
-    TorchCompileTesterMixin,
-    TrainingTesterMixin,
-)
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()

-class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
-    @property
-    def model_class(self):
-        return QwenImageTransformer2DModel
+class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
+    model_class = QwenImageTransformer2DModel
+    main_input_name = "hidden_states"
+    # We override the items here because the transformer under consideration is small.
+    model_split_percents = [0.7, 0.6, 0.6]
+    # Skip setting testing with default: AttnProcessor
+    uses_custom_attn_processor = True

    @property
-    def output_shape(self) -> tuple[int, int]:
+    def dummy_input(self):
+        return self.prepare_dummy_input()
+
+    @property
+    def input_shape(self):
        return (16, 16)

    @property
-    def input_shape(self) -> tuple[int, int]:
+    def output_shape(self):
        return (16, 16)

-    @property
-    def model_split_percents(self) -> list:
-        return [0.7, 0.6, 0.6]
-
-    @property
-    def main_input_name(self) -> str:
-        return "hidden_states"
-
-    @property
-    def generator(self):
-        return torch.Generator("cpu").manual_seed(0)
-
-    def get_init_dict(self) -> dict[str, int | list[int]]:
-        return {
-            "patch_size": 2,
-            "in_channels": 16,
-            "out_channels": 4,
-            "num_layers": 2,
-            "attention_head_dim": 16,
-            "num_attention_heads": 4,
-            "joint_attention_dim": 16,
-            "guidance_embeds": False,
-            "axes_dims_rope": (8, 4, 4),
-        }
-
-    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
+    def prepare_dummy_input(self, height=4, width=4):
        batch_size = 1
        num_latent_channels = embedding_dim = 16
-        height = width = 4
-        sequence_length = 8
+        sequence_length = 7
        vae_scale_factor = 4
-        hidden_states = randn_tensor(
-            (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
-        )
-        encoder_hidden_states = randn_tensor(
-            (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
-        )
+        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
        encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
        orig_height = height * 2 * vae_scale_factor
@@ -104,57 +70,89 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
"img_shapes": img_shapes,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
def test_infers_text_seq_len_from_mask(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
"""Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
# Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask[:, 2:] = 0
encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid
rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask
)
assert isinstance(rope_text_seq_len, int)
assert isinstance(per_sample_len, torch.Tensor)
assert int(per_sample_len.max().item()) == 2
assert normalized_mask.dtype == torch.bool
assert normalized_mask.sum().item() == 2
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
# Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
self.assertIsInstance(rope_text_seq_len, int)
# Verify per_sample_len is computed correctly (max valid position + 1 = 2)
self.assertIsInstance(per_sample_len, torch.Tensor)
self.assertEqual(int(per_sample_len.max().item()), 2)
# Verify mask is normalized to bool dtype
self.assertTrue(normalized_mask.dtype == torch.bool)
self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values
# Verify rope_text_seq_len is at least the sequence length
self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
# Test 2: Verify model runs successfully with inferred values
inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad():
output = model(**inputs)
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
# Test 3: Different mask pattern (padding at beginning)
encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask2[:, :3] = 0
encoder_hidden_states_mask2[:, 3:] = 1
encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding
encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid
rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask2
)
assert int(per_sample_len2.max().item()) == 8
assert normalized_mask2.sum().item() == 5
# Max valid position is 6 (last token), so per_sample_len should be 7
self.assertEqual(int(per_sample_len2.max().item()), 7)
self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values
# Test 4: No mask provided (None case)
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], None
)
assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1]
assert isinstance(rope_text_seq_len_none, int)
assert per_sample_len_none is None
assert normalized_mask_none is None
self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(rope_text_seq_len_none, int)
self.assertIsNone(per_sample_len_none)
self.assertIsNone(normalized_mask_none)
def test_non_contiguous_attention_mask(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
# Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
# Pattern: [True, False, True, False, True, False, False]
encoder_hidden_states_mask[:, 1] = 0
encoder_hidden_states_mask[:, 3] = 0
encoder_hidden_states_mask[:, 5:] = 0
@@ -162,85 +160,95 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
        inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
            inputs["encoder_hidden_states"], encoder_hidden_states_mask
        )
-        assert int(per_sample_len.max().item()) == 5
-        assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1]
-        assert isinstance(inferred_rope_len, int)
-        assert normalized_mask.dtype == torch.bool
+        self.assertEqual(int(per_sample_len.max().item()), 5)
+        self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
+        self.assertIsInstance(inferred_rope_len, int)
+        self.assertTrue(normalized_mask.dtype == torch.bool)

        inputs["encoder_hidden_states_mask"] = normalized_mask
        with torch.no_grad():
            output = model(**inputs)
-        assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
+        self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])

    def test_txt_seq_lens_deprecation(self):
-        init_dict = self.get_init_dict()
-        inputs = self.get_dummy_inputs()
+        """Test that passing txt_seq_lens raises a deprecation warning."""
+        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device)

+        # Prepare inputs with txt_seq_lens (deprecated parameter)
        txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]
        # Remove encoder_hidden_states_mask to use the deprecated path
        inputs_with_deprecated = inputs.copy()
        inputs_with_deprecated.pop("encoder_hidden_states_mask")
        inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens

-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter("always")
+        # Test that deprecation warning is raised
+        with self.assertWarns(FutureWarning) as warning_context:
            with torch.no_grad():
                output = model(**inputs_with_deprecated)
-        future_warnings = [x for x in w if issubclass(x.category, FutureWarning)]
-        assert len(future_warnings) > 0, "Expected FutureWarning to be raised"
+        # Verify the warning message mentions the deprecation
+        warning_message = str(warning_context.warning)
+        self.assertIn("txt_seq_lens", warning_message)
+        self.assertIn("deprecated", warning_message)
+        self.assertIn("encoder_hidden_states_mask", warning_message)
-        warning_message = str(future_warnings[0].message)
-        assert "txt_seq_lens" in warning_message
-        assert "deprecated" in warning_message
-        assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
+        # Verify the model still works correctly despite the deprecation
+        self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])

    def test_layered_model_with_mask(self):
+        """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
+        # Create layered model config
        init_dict = {
            "patch_size": 2,
            "in_channels": 16,
            "out_channels": 4,
            "num_layers": 2,
            "attention_head_dim": 16,
-            "num_attention_heads": 4,
+            "num_attention_heads": 3,
            "joint_attention_dim": 16,
-            "axes_dims_rope": (8, 4, 4),
-            "use_layer3d_rope": True,
-            "use_additional_t_cond": True,
+            "axes_dims_rope": (8, 4, 4),  # Must match attention_head_dim (8+4+4=16)
+            "use_layer3d_rope": True,  # Enable layered RoPE
+            "use_additional_t_cond": True,  # Enable additional time conditioning
        }
        model = self.model_class(**init_dict).to(torch_device)

+        # Verify the model uses QwenEmbedLayer3DRope
        from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope

-        assert isinstance(model.pos_embed, QwenEmbedLayer3DRope)
+        self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)

+        # Test single generation with layered structure
        batch_size = 1
-        text_seq_len = 8
+        text_seq_len = 7
        img_h, img_w = 4, 4
        layers = 4

+        # For layered model: (layers + 1) because we have N layers + 1 combined image
        hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
        encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)
+        # Create mask with some padding
        encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
-        encoder_hidden_states_mask[0, 5:] = 0
+        encoder_hidden_states_mask[0, 5:] = 0  # Only 5 valid tokens
        timestep = torch.tensor([1.0]).to(torch_device)
+        # additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding)
        addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)

+        # Layer structure: 4 layers + 1 condition image
        img_shapes = [
            [
-                (1, img_h, img_w),
-                (1, img_h, img_w),
-                (1, img_h, img_w),
-                (1, img_h, img_w),
-                (1, img_h, img_w),
+                (1, img_h, img_w),  # layer 0
+                (1, img_h, img_w),  # layer 1
+                (1, img_h, img_w),  # layer 2
+                (1, img_h, img_w),  # layer 3
+                (1, img_h, img_w),  # condition image (last one gets special treatment)
            ]
        ]
@@ -254,113 +262,37 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
            additional_t_cond=addition_t_cond,
        )
-        assert output.sample.shape[1] == hidden_states.shape[1]
+        self.assertEqual(output.sample.shape[1], hidden_states.shape[1])

-class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
-    """Memory optimization tests for QwenImage Transformer."""
+class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+    model_class = QwenImageTransformer2DModel

+    def prepare_init_args_and_inputs_for_common(self):
+        return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()

-class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
-    """Training tests for QwenImage Transformer."""

+    def prepare_dummy_input(self, height, width):
+        return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)

-    def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"QwenImageTransformer2DModel"}
-        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

-class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
-    """Attention processor tests for QwenImage Transformer."""

-class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
-    """Context Parallel inference tests for QwenImage Transformer."""

-class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
-    """LoRA adapter tests for QwenImage Transformer."""

-class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
-    """LoRA hot-swapping tests for QwenImage Transformer."""

-    @property
-    def different_shapes_for_compilation(self):
-        return [(4, 4), (4, 8), (8, 8)]

-    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
-        batch_size = 1
-        num_latent_channels = embedding_dim = 16
-        sequence_length = 8
-        vae_scale_factor = 4
-        hidden_states = randn_tensor(
-            (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
-        )
-        encoder_hidden_states = randn_tensor(
-            (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
-        )
-        encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
-        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
-        orig_height = height * 2 * vae_scale_factor
-        orig_width = width * 2 * vae_scale_factor
-        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
-        return {
-            "hidden_states": hidden_states,
-            "encoder_hidden_states": encoder_hidden_states,
-            "encoder_hidden_states_mask": encoder_hidden_states_mask,
-            "timestep": timestep,
-            "img_shapes": img_shapes,
-        }

-class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
-    """Torch compile tests for QwenImage Transformer."""

    @property
    def different_shapes_for_compilation(self):
        return [(4, 4), (4, 8), (8, 8)]

-    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
-        batch_size = 1
-        num_latent_channels = embedding_dim = 16
-        sequence_length = 8
-        vae_scale_factor = 4
-        hidden_states = randn_tensor(
-            (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
-        )
-        encoder_hidden_states = randn_tensor(
-            (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
-        )
-        encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
-        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
-        orig_height = height * 2 * vae_scale_factor
-        orig_width = width * 2 * vae_scale_factor
-        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
-        return {
-            "hidden_states": hidden_states,
-            "encoder_hidden_states": encoder_hidden_states,
-            "encoder_hidden_states_mask": encoder_hidden_states_mask,
-            "timestep": timestep,
-            "img_shapes": img_shapes,
-        }

    def test_torch_compile_recompilation_and_graph_break(self):
        super().test_torch_compile_recompilation_and_graph_break()

    def test_torch_compile_with_and_without_mask(self):
-        init_dict = self.get_init_dict()
-        inputs = self.get_dummy_inputs()
+        """Test that torch.compile works with both None mask and padding mask."""
+        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()
        model.compile(mode="default", fullgraph=True)

+        # Test 1: Run with None mask (no padding, all tokens are valid)
        inputs_no_mask = inputs.copy()
        inputs_no_mask["encoder_hidden_states_mask"] = None
+        # First run to allow compilation
        with torch.no_grad():
            output_no_mask = model(**inputs_no_mask)
+        # Second run to verify no recompilation
        with (
            torch._inductor.utils.fresh_inductor_cache(),
            torch._dynamo.config.patch(error_on_recompile=True),
@@ -368,15 +300,19 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
        ):
            output_no_mask_2 = model(**inputs_no_mask)
-        assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1]
-        assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1]
+        self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
+        self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])

+        # Test 2: Run with all-ones mask (should behave like None)
        inputs_all_ones = inputs.copy()
-        assert inputs_all_ones["encoder_hidden_states_mask"].all().item()
+        # Keep the all-ones mask
+        self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
+        # First run to allow compilation
        with torch.no_grad():
            output_all_ones = model(**inputs_all_ones)
+        # Second run to verify no recompilation
        with (
            torch._inductor.utils.fresh_inductor_cache(),
            torch._dynamo.config.patch(error_on_recompile=True),
@@ -384,18 +320,21 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
        ):
            output_all_ones_2 = model(**inputs_all_ones)
-        assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1]
-        assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1]
+        self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
+        self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])

+        # Test 3: Run with actual padding mask (has zeros)
        inputs_with_padding = inputs.copy()
        mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
-        mask_with_padding[:, 4:] = 0
+        mask_with_padding[:, 4:] = 0  # Last 3 tokens are padding
        inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding
+        # First run to allow compilation
        with torch.no_grad():
            output_with_padding = model(**inputs_with_padding)
+        # Second run to verify no recompilation
        with (
            torch._inductor.utils.fresh_inductor_cache(),
            torch._dynamo.config.patch(error_on_recompile=True),
@@ -403,15 +342,8 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
        ):
            output_with_padding_2 = model(**inputs_with_padding)
-        assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1]
-        assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1]
+        self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
+        self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])

-        assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)

-class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin):
-    """BitsAndBytes quantization tests for QwenImage Transformer."""

-class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin):
-    """TorchAO quantization tests for QwenImage Transformer."""

+        # Verify that outputs are different (mask should affect results)
+        self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
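
For orientation, the contract these tests pin down for compute_text_seq_len_from_mask can be restated as a small sketch. This is inferred only from the assertions above, not from diffusers' actual implementation:

import torch

def compute_text_seq_len_from_mask_sketch(encoder_hidden_states, encoder_hidden_states_mask):
    # No mask: RoPE spans the full text sequence and there is nothing per-sample to report.
    seq_len = encoder_hidden_states.shape[1]  # a plain Python int, which keeps torch.compile happy
    if encoder_hidden_states_mask is None:
        return seq_len, None, None
    bool_mask = encoder_hidden_states_mask.to(torch.bool)  # normalize 0/1 masks of any dtype
    positions = torch.arange(seq_len, device=bool_mask.device)
    # Per sample: highest valid index + 1, e.g. [1, 0, 1, 0, 1, 0, 0] -> 5.
    per_sample_len = torch.where(bool_mask, positions + 1, torch.zeros_like(positions)).amax(dim=1)
    return seq_len, per_sample_len, bool_mask

# Values mirror the non-contiguous-mask test above.
mask = torch.tensor([[1, 0, 1, 0, 1, 0, 0]])
states = torch.randn(1, 7, 16)
rope_len, lens, norm = compute_text_seq_len_from_mask_sketch(states, mask)
assert rope_len == 7 and int(lens.max()) == 5 and norm.dtype == torch.bool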


@@ -31,41 +31,7 @@ from diffusers.modular_pipelines import (
    WanModularPipeline,
)

-from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device
-
-def _create_tiny_model_dir(model_dir):
-    TINY_MODEL_CODE = (
-        "import torch\n"
-        "from diffusers import ModelMixin, ConfigMixin\n"
-        "from diffusers.configuration_utils import register_to_config\n"
-        "\n"
-        "class TinyModel(ModelMixin, ConfigMixin):\n"
-        "    @register_to_config\n"
-        "    def __init__(self, hidden_size=4):\n"
-        "        super().__init__()\n"
-        "        self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
-        "\n"
-        "    def forward(self, x):\n"
-        "        return self.linear(x)\n"
-    )
-    with open(os.path.join(model_dir, "modeling.py"), "w") as f:
-        f.write(TINY_MODEL_CODE)
-    config = {
-        "_class_name": "TinyModel",
-        "_diffusers_version": "0.0.0",
-        "auto_map": {"AutoModel": "modeling.TinyModel"},
-        "hidden_size": 4,
-    }
-    with open(os.path.join(model_dir, "config.json"), "w") as f:
-        json.dump(config, f)
-    torch.save(
-        {"linear.weight": torch.randn(4, 4), "linear.bias": torch.randn(4)},
-        os.path.join(model_dir, "diffusion_pytorch_model.bin"),
-    )
+from ..testing_utils import nightly, require_torch, slow

class DummyCustomBlockSimple(ModularPipelineBlocks):
@@ -375,81 +341,6 @@ class TestModularCustomBlocks:
        loaded_pipe.update_components(custom_model=custom_model)
        assert getattr(loaded_pipe, "custom_model", None) is not None

-    def test_automodel_type_hint_preserves_torch_dtype(self, tmp_path):
-        """Regression test for #13271: torch_dtype was incorrectly removed when type_hint is AutoModel."""
-        from diffusers import AutoModel
-
-        model_dir = str(tmp_path / "model")
-        os.makedirs(model_dir)
-        _create_tiny_model_dir(model_dir)
-
-        class DtypeTestBlock(ModularPipelineBlocks):
-            @property
-            def expected_components(self):
-                return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)]
-
-            @property
-            def inputs(self) -> List[InputParam]:
-                return [InputParam("prompt", type_hint=str, required=True)]
-
-            @property
-            def intermediate_inputs(self) -> List[InputParam]:
-                return []
-
-            @property
-            def intermediate_outputs(self) -> List[OutputParam]:
-                return [OutputParam("output", type_hint=str)]
-
-            def __call__(self, components, state: PipelineState) -> PipelineState:
-                block_state = self.get_block_state(state)
-                block_state.output = "test"
-                self.set_block_state(state, block_state)
-                return components, state
-
-        block = DtypeTestBlock()
-        pipe = block.init_pipeline()
-        pipe.load_components(torch_dtype=torch.float16, trust_remote_code=True)
-        assert pipe.model.dtype == torch.float16
-
-    @require_torch_accelerator
-    def test_automodel_type_hint_preserves_device(self, tmp_path):
-        """Test that ComponentSpec with AutoModel type_hint correctly passes device_map."""
-        from diffusers import AutoModel
-
-        model_dir = str(tmp_path / "model")
-        os.makedirs(model_dir)
-        _create_tiny_model_dir(model_dir)
-
-        class DeviceTestBlock(ModularPipelineBlocks):
-            @property
-            def expected_components(self):
-                return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)]
-
-            @property
-            def inputs(self) -> List[InputParam]:
-                return [InputParam("prompt", type_hint=str, required=True)]
-
-            @property
-            def intermediate_inputs(self) -> List[InputParam]:
-                return []
-
-            @property
-            def intermediate_outputs(self) -> List[OutputParam]:
-                return [OutputParam("output", type_hint=str)]
-
-            def __call__(self, components, state: PipelineState) -> PipelineState:
-                block_state = self.get_block_state(state)
-                block_state.output = "test"
-                self.set_block_state(state, block_state)
-                return components, state
-
-        block = DeviceTestBlock()
-        pipe = block.init_pipeline()
-        pipe.load_components(device_map=torch_device, trust_remote_code=True)
-        assert pipe.model.device.type == torch_device
-
    def test_custom_block_loads_from_hub(self):
        repo_id = "hf-internal-testing/tiny-modular-diffusers-block"
        block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
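
The two deleted tests above exercised remote-code loading against a throwaway directory written by _create_tiny_model_dir: a config.json with an auto_map entry, a modeling.py defining TinyModel, and a .bin state dict. For reference, a minimal sketch of loading such a directory directly, assuming diffusers' AutoModel honors auto_map when trust_remote_code is set (the path ./tiny_model is hypothetical, standing in for the tmp_path the tests used):

import torch
from diffusers import AutoModel

model_dir = "./tiny_model"  # hypothetical: laid out like _create_tiny_model_dir's output
model = AutoModel.from_pretrained(model_dir, trust_remote_code=True, torch_dtype=torch.float16)
print(model(torch.randn(1, 4, dtype=torch.float16)).shape)  # TinyModel is just a 4x4 linear layer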

utils/make_tiny_model.py (new file, 210 lines)

@@ -0,0 +1,210 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "diffusers",
# "torch",
# "huggingface_hub",
# "accelerate",
# "transformers",
# "sentencepiece",
# "protobuf",
# ]
# ///
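# Note: the block above is PEP 723 inline script metadata, which is what lets this file run as a
# self-contained script (e.g. `uv run utils/make_tiny_model.py ...`) and be shipped to HF Jobs
# via run_uv_job below without a separate requirements file.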
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility script to create tiny versions of diffusers models by reducing layer counts.
Can be run locally or submitted as an HF Job via `--launch`.
Usage:
# Run locally
python make_tiny_model.py --model_repo_id <model_repo_id> --output_repo_id <output_repo_id> [--subfolder transformer] [--num_layers 2]
# Push to Hub
python make_tiny_model.py --model_repo_id <model_repo_id> --output_repo_id <output_repo_id> --push_to_hub --token $HF_TOKEN
# Submit as an HF Job
python make_tiny_model.py --model_repo_id <model_repo_id> --output_repo_id <output_repo_id> --launch [--flavor cpu-basic]
"""
import argparse
import os
import re

LAYER_PARAM_PATTERN = re.compile(r"^(num_.*layers?|n_layers|n_refiner_layers)$")

DIM_PARAM_PATTERNS = {
    re.compile(r"^num_attention_heads$"): 2,
    re.compile(r"^num_.*attention_heads$"): 2,
    re.compile(r"^num_key_value_heads$"): 2,
    re.compile(r"^num_kv_heads$"): 1,
    re.compile(r"^n_heads$"): 2,
    re.compile(r"^n_kv_heads$"): 2,
    re.compile(r"^attention_head_dim$"): 8,
    re.compile(r"^.*attention_head_dim$"): 4,
    re.compile(r"^cross_attention_dim.*$"): 8,
    re.compile(r"^joint_attention_dim$"): 32,
    re.compile(r"^pooled_projection_dim$"): 32,
    re.compile(r"^caption_projection_dim$"): 32,
    re.compile(r"^caption_channels$"): 8,
    re.compile(r"^cap_feat_dim$"): 16,
    re.compile(r"^hidden_size$"): 16,
    re.compile(r"^dim$"): 16,
    re.compile(r"^.*embed_dim$"): 16,
    re.compile(r"^.*embed_.*dim$"): 16,
    re.compile(r"^text_dim$"): 16,
    re.compile(r"^time_embed_dim$"): 4,
    re.compile(r"^ffn_dim$"): 32,
    re.compile(r"^intermediate_size$"): 32,
    re.compile(r"^sample_size$"): 32,
}
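
# Illustrative matches for the tables above (hypothetical config keys, not from any specific
# model): "num_layers", "num_single_layers", or "n_layers" are clamped to --num_layers via
# LAYER_PARAM_PATTERN; with --shrink_dims, "num_attention_heads" -> 2, "attention_head_dim" -> 8,
# and "hidden_size" -> 16. Ordering matters: the exact ^attention_head_dim$ -> 8 entry precedes
# the broader ^.*attention_head_dim$ -> 4 entry because the first matching pattern wins (see the
# `break` in make_tiny_model below).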

def parse_args():
    parser = argparse.ArgumentParser(description="Create a tiny version of a diffusers model.")
    parser.add_argument("--model_repo_id", type=str, required=True, help="HuggingFace repo ID of the source model.")
    parser.add_argument(
        "--output_repo_id",
        type=str,
        required=True,
        help="HuggingFace repo ID or local path to save the tiny model to.",
    )
    parser.add_argument("--subfolder", type=str, default=None, help="Subfolder within the model repo.")
    parser.add_argument("--num_layers", type=int, default=2, help="Number of layers to use for the tiny model.")
    parser.add_argument(
        "--shrink_dims",
        action="store_true",
        help="Also reduce dimension parameters (attention heads, hidden size, embedding dims, etc.).",
    )
    parser.add_argument("--push_to_hub", action="store_true", help="Push the tiny model to the HuggingFace Hub.")
    parser.add_argument(
        "--token", type=str, default=None, help="HuggingFace token. Defaults to $HF_TOKEN env var if not provided."
    )
    launch_group = parser.add_argument_group("HF Jobs launch options")
    launch_group.add_argument("--launch", action="store_true", help="Submit as an HF Job instead of running locally.")
    launch_group.add_argument("--flavor", type=str, default="cpu-basic", help="HF Jobs hardware flavor.")
    launch_group.add_argument("--timeout", type=str, default="30m", help="HF Jobs timeout.")
    args = parser.parse_args()
    if args.token is None:
        args.token = os.environ.get("HF_TOKEN")
    return args


def launch_job(args):
    from huggingface_hub import run_uv_job

    script_args = [
        "--model_repo_id",
        args.model_repo_id,
        "--output_repo_id",
        args.output_repo_id,
        "--num_layers",
        str(args.num_layers),
    ]
    if args.subfolder:
        script_args.extend(["--subfolder", args.subfolder])
    if args.shrink_dims:
        script_args.append("--shrink_dims")
    if args.push_to_hub:
        script_args.append("--push_to_hub")

    job = run_uv_job(
        __file__,
        script_args=script_args,
        flavor=args.flavor,
        timeout=args.timeout,
        secrets={"HF_TOKEN": args.token} if args.token else {},
    )
    print(f"Job submitted: {job.url}")
    print(f"Job ID: {job.id}")
    return job


def make_tiny_model(
    model_repo_id, output_repo_id, subfolder=None, num_layers=2, shrink_dims=False, push_to_hub=False, token=None
):
    from diffusers import AutoModel

    config_kwargs = {}
    if token:
        config_kwargs["token"] = token
    config = AutoModel.load_config(model_repo_id, subfolder=subfolder, **config_kwargs)

    modified_keys = {}
    for key, value in config.items():
        if LAYER_PARAM_PATTERN.match(key) and isinstance(value, int) and value > num_layers:
            modified_keys[key] = (value, num_layers)
            config[key] = num_layers

    if shrink_dims:
        for key, value in config.items():
            if not isinstance(value, int) or key.startswith("_"):
                continue
            for pattern, tiny_value in DIM_PARAM_PATTERNS.items():
                if pattern.match(key) and value > tiny_value:
                    modified_keys[key] = (value, tiny_value)
                    config[key] = tiny_value
                    break

    if not modified_keys:
        print("WARNING: No config parameters were modified.")
        print(f"Config keys: {[k for k in config if not k.startswith('_')]}")
        return

    print("Modified config parameters:")
    for key, (old, new) in modified_keys.items():
        print(f"  {key}: {old} -> {new}")

    model = AutoModel.from_config(config)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Tiny model created with {total_params:,} parameters.")

    save_kwargs = {}
    if token:
        save_kwargs["token"] = token
    if push_to_hub:
        save_kwargs["repo_id"] = output_repo_id
    model.save_pretrained(output_repo_id, push_to_hub=push_to_hub, **save_kwargs)
    if push_to_hub:
        print(f"Model pushed to https://huggingface.co/{output_repo_id}")
    else:
        print(f"Model saved to {output_repo_id}")


def main():
    args = parse_args()
    if args.launch:
        launch_job(args)
    else:
        make_tiny_model(
            model_repo_id=args.model_repo_id,
            output_repo_id=args.output_repo_id,
            subfolder=args.subfolder,
            num_layers=args.num_layers,
            shrink_dims=args.shrink_dims,
            push_to_hub=args.push_to_hub,
            token=args.token,
        )


if __name__ == "__main__":
    main()
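
To make the rewrite concrete, a hypothetical config shaped like a large DiT transformer would be shrunk as follows (illustrative keys and values, not taken from any particular checkpoint):

# python utils/make_tiny_model.py --model_repo_id <repo> --output_repo_id ./tiny --num_layers 2 --shrink_dims
#   before: {"num_layers": 19, "num_single_layers": 38, "attention_head_dim": 128, "num_attention_heads": 24}
#   after:  {"num_layers": 2,  "num_single_layers": 2,  "attention_head_dim": 8,   "num_attention_heads": 2}
# ("num_single_layers" matches the num_.*layers? branch of LAYER_PARAM_PATTERN; the dim keys are
#  only touched when --shrink_dims is passed.)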