Compare commits

..

8 Commits

Author SHA1 Message Date
DN6
1d75ab7e35 update 2026-03-17 23:23:01 +05:30
DN6
b086b6da0a update 2026-03-17 15:40:27 +05:30
DN6
28c7516229 update 2026-03-17 15:38:03 +05:30
DN6
53b9b56059 update 2026-03-17 15:33:17 +05:30
DN6
65579667e9 update 2026-03-17 15:31:30 +05:30
DN6
cda1c36eeb update 2026-03-17 13:41:49 +05:30
DN6
f634485333 Merge branch 'main' into make-tiny-model 2026-03-17 13:37:02 +05:30
DN6
7820980959 update 2026-03-17 13:36:49 +05:30
9 changed files with 399 additions and 331 deletions

View File

@@ -12,7 +12,6 @@ from termcolor import colored
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import (
AutoencoderKLLTX2Video,
AutoencoderKLWan,
DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler,
@@ -25,10 +24,7 @@ from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = [
"Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth",
"Efficient-Large-Model/SANA-Video_2B_720p/checkpoints/SANA_Video_2B_720p_LTXVAE.pth",
]
ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
@@ -96,22 +92,12 @@ def main(args):
if args.video_size == 480:
sample_size = 30 # Wan-VAE: 8xp2 downsample factor
patch_size = (1, 2, 2)
in_channels = 16
out_channels = 16
elif args.video_size == 720:
sample_size = 22 # DC-AE-V: 32xp1 downsample factor
sample_size = 22 # Wan-VAE: 32xp1 downsample factor
patch_size = (1, 1, 1)
in_channels = 32
out_channels = 32
else:
raise ValueError(f"Video size {args.video_size} is not supported.")
if args.vae_type == "ltx2":
sample_size = 22
patch_size = (1, 1, 1)
in_channels = 128
out_channels = 128
for depth in range(layer_num):
# Transformer blocks.
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
@@ -196,8 +182,8 @@ def main(args):
# Transformer
with CTX():
transformer_kwargs = {
"in_channels": in_channels,
"out_channels": out_channels,
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 20,
"attention_head_dim": 112,
"num_layers": 20,
@@ -249,12 +235,9 @@ def main(args):
else:
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
# VAE
if args.vae_type == "ltx2":
vae_path = args.vae_path or "Lightricks/LTX-2"
vae = AutoencoderKLLTX2Video.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
else:
vae_path = args.vae_path or "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
# Text Encoder
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
@@ -331,23 +314,7 @@ if __name__ == "__main__":
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
help="Scheduler type to use.",
)
parser.add_argument(
"--vae_type",
default="wan",
type=str,
choices=["wan", "ltx2"],
help="VAE type to use for saving full pipeline (ltx2 uses patchify 1x1x1).",
)
parser.add_argument(
"--vae_path",
default=None,
type=str,
required=False,
help="Optional VAE path or repo id. If not set, a default is used per VAE type.",
)
parser.add_argument(
"--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v."
)
parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import math
from functools import lru_cache
from typing import Any
import torch
@@ -342,6 +343,7 @@ class HeliosRotaryPosEmbed(nn.Module):
return freqs.cos(), freqs.sin()
@torch.no_grad()
@lru_cache(maxsize=32)
def _get_spatial_meshgrid(self, height, width, device_str):
device = torch.device(device_str)
grid_y_coords = torch.arange(height, device=device, dtype=torch.float32)

View File

@@ -720,7 +720,6 @@ class LDMBertModel(LDMBertPreTrainedModel):
super().__init__(config)
self.model = LDMBertEncoder(config)
self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
self.post_init()
def forward(
self,

View File

@@ -35,8 +35,6 @@ class PaintByExampleImageEncoder(CLIPPreTrainedModel):
# uncondition for scaling
self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))
self.post_init()
def forward(self, pixel_values, return_uncond_vector=False):
clip_output = self.model(pixel_values=pixel_values)
latent_states = clip_output.pooler_output

View File

@@ -24,7 +24,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import SanaLoraLoaderMixin
from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
@@ -194,7 +194,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
The tokenizer used to tokenize the prompt.
text_encoder ([`Gemma2PreTrainedModel`]):
Text encoder model to encode the input prompts.
vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
transformer ([`SanaVideoTransformer3DModel`]):
Conditional Transformer to denoise the input latents.
@@ -213,7 +213,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
self,
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
text_encoder: Gemma2PreTrainedModel,
vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
vae: AutoencoderDC | AutoencoderKLWan,
transformer: SanaVideoTransformer3DModel,
scheduler: DPMSolverMultistepScheduler,
):
@@ -223,19 +223,8 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
if getattr(self, "vae", None):
if isinstance(self.vae, AutoencoderKLLTX2Video):
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
self.vae_scale_factor = self.vae_scale_factor_spatial
@@ -996,21 +985,14 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError
)
if isinstance(self.vae, AutoencoderKLLTX2Video):
latents_mean = self.vae.latents_mean
latents_std = self.vae.latents_std
z_dim = self.vae.config.latent_channels
elif isinstance(self.vae, AutoencoderKLWan):
latents_mean = torch.tensor(self.vae.config.latents_mean)
latents_std = torch.tensor(self.vae.config.latents_std)
z_dim = self.vae.config.z_dim
else:
latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
z_dim = latents.shape[1]
latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = latents / latents_std + latents_mean
try:
video = self.vae.decode(latents, return_dict=False)[0]

View File

@@ -26,7 +26,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...loaders import SanaLoraLoaderMixin
from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
BACKENDS_MAPPING,
@@ -184,7 +184,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
The tokenizer used to tokenize the prompt.
text_encoder ([`Gemma2PreTrainedModel`]):
Text encoder model to encode the input prompts.
vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
transformer ([`SanaVideoTransformer3DModel`]):
Conditional Transformer to denoise the input latents.
@@ -203,7 +203,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
self,
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
text_encoder: Gemma2PreTrainedModel,
vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
vae: AutoencoderDC | AutoencoderKLWan,
transformer: SanaVideoTransformer3DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
):
@@ -213,19 +213,8 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
if getattr(self, "vae", None):
if isinstance(self.vae, AutoencoderKLLTX2Video):
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
self.vae_scale_factor = self.vae_scale_factor_spatial
@@ -698,18 +687,14 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1)
if isinstance(self.vae, AutoencoderKLLTX2Video):
_latents_mean = self.vae.latents_mean
_latents_std = self.vae.latents_std
elif isinstance(self.vae, AutoencoderKLWan):
_latents_mean = torch.tensor(self.vae.config.latents_mean)
_latents_std = torch.tensor(self.vae.config.latents_std)
else:
_latents_mean = torch.zeros(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
_latents_std = torch.ones(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
latents_mean = _latents_mean.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_std = 1.0 / _latents_std.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, -1, 1, 1, 1)
.to(image_latents.device, image_latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(
image_latents.device, image_latents.dtype
)
image_latents = (image_latents - latents_mean) * latents_std
latents[:, :, 0:1] = image_latents.to(dtype)
@@ -1049,21 +1034,14 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError
)
if isinstance(self.vae, AutoencoderKLLTX2Video):
latents_mean = self.vae.latents_mean
latents_std = self.vae.latents_std
z_dim = self.vae.config.latent_channels
elif isinstance(self.vae, AutoencoderKLWan):
latents_mean = torch.tensor(self.vae.config.latents_mean)
latents_std = torch.tensor(self.vae.config.latents_std)
z_dim = self.vae.config.z_dim
else:
latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
z_dim = latents.shape[1]
latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = latents / latents_std + latents_mean
try:
video = self.vae.decode(latents, return_dict=False)[0]

View File

@@ -1,3 +1,4 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,84 +13,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
import unittest
import torch
from diffusers import QwenImageTransformer2DModel
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return QwenImageTransformer2DModel
class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def output_shape(self) -> tuple[int, int]:
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
return (16, 16)
@property
def input_shape(self) -> tuple[int, int]:
def output_shape(self):
return (16, 16)
@property
def model_split_percents(self) -> list:
return [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
return {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 4,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
def prepare_dummy_input(self, height=4, width=4):
batch_size = 1
num_latent_channels = embedding_dim = 16
height = width = 4
sequence_length = 8
sequence_length = 7
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
@@ -104,57 +70,89 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
"img_shapes": img_shapes,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
def test_infers_text_seq_len_from_mask(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
"""Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
# Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask[:, 2:] = 0
encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid
rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask
)
assert isinstance(rope_text_seq_len, int)
assert isinstance(per_sample_len, torch.Tensor)
assert int(per_sample_len.max().item()) == 2
assert normalized_mask.dtype == torch.bool
assert normalized_mask.sum().item() == 2
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
# Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
self.assertIsInstance(rope_text_seq_len, int)
# Verify per_sample_len is computed correctly (max valid position + 1 = 2)
self.assertIsInstance(per_sample_len, torch.Tensor)
self.assertEqual(int(per_sample_len.max().item()), 2)
# Verify mask is normalized to bool dtype
self.assertTrue(normalized_mask.dtype == torch.bool)
self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values
# Verify rope_text_seq_len is at least the sequence length
self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
# Test 2: Verify model runs successfully with inferred values
inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad():
output = model(**inputs)
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
# Test 3: Different mask pattern (padding at beginning)
encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask2[:, :3] = 0
encoder_hidden_states_mask2[:, 3:] = 1
encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding
encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid
rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask2
)
assert int(per_sample_len2.max().item()) == 8
assert normalized_mask2.sum().item() == 5
# Max valid position is 6 (last token), so per_sample_len should be 7
self.assertEqual(int(per_sample_len2.max().item()), 7)
self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values
# Test 4: No mask provided (None case)
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], None
)
assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1]
assert isinstance(rope_text_seq_len_none, int)
assert per_sample_len_none is None
assert normalized_mask_none is None
self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(rope_text_seq_len_none, int)
self.assertIsNone(per_sample_len_none)
self.assertIsNone(normalized_mask_none)
def test_non_contiguous_attention_mask(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
# Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
# Pattern: [True, False, True, False, True, False, False]
encoder_hidden_states_mask[:, 1] = 0
encoder_hidden_states_mask[:, 3] = 0
encoder_hidden_states_mask[:, 5:] = 0
@@ -162,85 +160,95 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask
)
assert int(per_sample_len.max().item()) == 5
assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1]
assert isinstance(inferred_rope_len, int)
assert normalized_mask.dtype == torch.bool
self.assertEqual(int(per_sample_len.max().item()), 5)
self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(inferred_rope_len, int)
self.assertTrue(normalized_mask.dtype == torch.bool)
inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad():
output = model(**inputs)
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
def test_txt_seq_lens_deprecation(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
"""Test that passing txt_seq_lens raises a deprecation warning."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
# Prepare inputs with txt_seq_lens (deprecated parameter)
txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]
# Remove encoder_hidden_states_mask to use the deprecated path
inputs_with_deprecated = inputs.copy()
inputs_with_deprecated.pop("encoder_hidden_states_mask")
inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
# Test that deprecation warning is raised
with self.assertWarns(FutureWarning) as warning_context:
with torch.no_grad():
output = model(**inputs_with_deprecated)
future_warnings = [x for x in w if issubclass(x.category, FutureWarning)]
assert len(future_warnings) > 0, "Expected FutureWarning to be raised"
# Verify the warning message mentions the deprecation
warning_message = str(warning_context.warning)
self.assertIn("txt_seq_lens", warning_message)
self.assertIn("deprecated", warning_message)
self.assertIn("encoder_hidden_states_mask", warning_message)
warning_message = str(future_warnings[0].message)
assert "txt_seq_lens" in warning_message
assert "deprecated" in warning_message
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
# Verify the model still works correctly despite the deprecation
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
def test_layered_model_with_mask(self):
"""Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
# Create layered model config
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 4,
"num_attention_heads": 3,
"joint_attention_dim": 16,
"axes_dims_rope": (8, 4, 4),
"use_layer3d_rope": True,
"use_additional_t_cond": True,
"axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16)
"use_layer3d_rope": True, # Enable layered RoPE
"use_additional_t_cond": True, # Enable additional time conditioning
}
model = self.model_class(**init_dict).to(torch_device)
# Verify the model uses QwenEmbedLayer3DRope
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope
assert isinstance(model.pos_embed, QwenEmbedLayer3DRope)
self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
# Test single generation with layered structure
batch_size = 1
text_seq_len = 8
text_seq_len = 7
img_h, img_w = 4, 4
layers = 4
# For layered model: (layers + 1) because we have N layers + 1 combined image
hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)
# Create mask with some padding
encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
encoder_hidden_states_mask[0, 5:] = 0
encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens
timestep = torch.tensor([1.0]).to(torch_device)
# additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding)
addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)
# Layer structure: 4 layers + 1 condition image
img_shapes = [
[
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w), # layer 0
(1, img_h, img_w), # layer 1
(1, img_h, img_w), # layer 2
(1, img_h, img_w), # layer 3
(1, img_h, img_w), # condition image (last one gets special treatment)
]
]
@@ -254,113 +262,37 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
additional_t_cond=addition_t_cond,
)
assert output.sample.shape[1] == hidden_states.shape[1]
self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for QwenImage Transformer."""
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
def prepare_init_args_and_inputs_for_common(self):
return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for QwenImage Transformer."""
def prepare_dummy_input(self, height, width):
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for QwenImage Transformer."""
class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for QwenImage Transformer."""
class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for QwenImage Transformer."""
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for QwenImage Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 8
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for QwenImage Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 8
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
def test_torch_compile_recompilation_and_graph_break(self):
super().test_torch_compile_recompilation_and_graph_break()
def test_torch_compile_with_and_without_mask(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
"""Test that torch.compile works with both None mask and padding mask."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile(mode="default", fullgraph=True)
# Test 1: Run with None mask (no padding, all tokens are valid)
inputs_no_mask = inputs.copy()
inputs_no_mask["encoder_hidden_states_mask"] = None
# First run to allow compilation
with torch.no_grad():
output_no_mask = model(**inputs_no_mask)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
@@ -368,15 +300,19 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
):
output_no_mask_2 = model(**inputs_no_mask)
assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
# Test 2: Run with all-ones mask (should behave like None)
inputs_all_ones = inputs.copy()
assert inputs_all_ones["encoder_hidden_states_mask"].all().item()
# Keep the all-ones mask
self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
# First run to allow compilation
with torch.no_grad():
output_all_ones = model(**inputs_all_ones)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
@@ -384,18 +320,21 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
):
output_all_ones_2 = model(**inputs_all_ones)
assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
# Test 3: Run with actual padding mask (has zeros)
inputs_with_padding = inputs.copy()
mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
mask_with_padding[:, 4:] = 0
mask_with_padding[:, 4:] = 0 # Last 3 tokens are padding
inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding
# First run to allow compilation
with torch.no_grad():
output_with_padding = model(**inputs_with_padding)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
@@ -403,15 +342,8 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
):
output_with_padding_2 = model(**inputs_with_padding)
assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1]
self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)
class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for QwenImage Transformer."""
class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for QwenImage Transformer."""
# Verify that outputs are different (mask should affect results)
self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))

View File

@@ -139,9 +139,9 @@ class HeliosPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
@unittest.skip("Helios uses a lot of mixed precision internally, which is not suitable for this test case")
# Override to set a more lenient max diff threshold.
def test_save_load_float16(self):
pass
super().test_save_load_float16(expected_max_diff=0.03)
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):

210
utils/make_tiny_model.py Normal file
View File

@@ -0,0 +1,210 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "diffusers",
# "torch",
# "huggingface_hub",
# "accelerate",
# "transformers",
# "sentencepiece",
# "protobuf",
# ]
# ///
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility script to create tiny versions of diffusers models by reducing layer counts.
Can be run locally or submitted as an HF Job via `--launch`.
Usage:
# Run locally
python make_tiny_model.py --model_repo_id <model_repo_id> --output_repo_id <output_repo_id> [--subfolder transformer] [--num_layers 2]
# Push to Hub
python make_tiny_model.py --model_repo_id <model_repo_id> --output_repo_id <output_repo_id> --push_to_hub --token $HF_TOKEN
# Submit as an HF Job
python make_tiny_model.py --model_repo_id <model_repo_id> --output_repo_id <output_repo_id> --launch [--flavor cpu-basic]
"""
import argparse
import os
import re
LAYER_PARAM_PATTERN = re.compile(r"^(num_.*layers?|n_layers|n_refiner_layers)$")
DIM_PARAM_PATTERNS = {
re.compile(r"^num_attention_heads$"): 2,
re.compile(r"^num_.*attention_heads$"): 2,
re.compile(r"^num_key_value_heads$"): 2,
re.compile(r"^num_kv_heads$"): 1,
re.compile(r"^n_heads$"): 2,
re.compile(r"^n_kv_heads$"): 2,
re.compile(r"^attention_head_dim$"): 8,
re.compile(r"^.*attention_head_dim$"): 4,
re.compile(r"^cross_attention_dim.*$"): 8,
re.compile(r"^joint_attention_dim$"): 32,
re.compile(r"^pooled_projection_dim$"): 32,
re.compile(r"^caption_projection_dim$"): 32,
re.compile(r"^caption_channels$"): 8,
re.compile(r"^cap_feat_dim$"): 16,
re.compile(r"^hidden_size$"): 16,
re.compile(r"^dim$"): 16,
re.compile(r"^.*embed_dim$"): 16,
re.compile(r"^.*embed_.*dim$"): 16,
re.compile(r"^text_dim$"): 16,
re.compile(r"^time_embed_dim$"): 4,
re.compile(r"^ffn_dim$"): 32,
re.compile(r"^intermediate_size$"): 32,
re.compile(r"^sample_size$"): 32,
}
def parse_args():
parser = argparse.ArgumentParser(description="Create a tiny version of a diffusers model.")
parser.add_argument("--model_repo_id", type=str, required=True, help="HuggingFace repo ID of the source model.")
parser.add_argument(
"--output_repo_id",
type=str,
required=True,
help="HuggingFace repo ID or local path to save the tiny model to.",
)
parser.add_argument("--subfolder", type=str, default=None, help="Subfolder within the model repo.")
parser.add_argument("--num_layers", type=int, default=2, help="Number of layers to use for the tiny model.")
parser.add_argument(
"--shrink_dims",
action="store_true",
help="Also reduce dimension parameters (attention heads, hidden size, embedding dims, etc.).",
)
parser.add_argument("--push_to_hub", action="store_true", help="Push the tiny model to the HuggingFace Hub.")
parser.add_argument(
"--token", type=str, default=None, help="HuggingFace token. Defaults to $HF_TOKEN env var if not provided."
)
launch_group = parser.add_argument_group("HF Jobs launch options")
launch_group.add_argument("--launch", action="store_true", help="Submit as an HF Job instead of running locally.")
launch_group.add_argument("--flavor", type=str, default="cpu-basic", help="HF Jobs hardware flavor.")
launch_group.add_argument("--timeout", type=str, default="30m", help="HF Jobs timeout.")
args = parser.parse_args()
if args.token is None:
args.token = os.environ.get("HF_TOKEN")
return args
def launch_job(args):
from huggingface_hub import run_uv_job
script_args = [
"--model_repo_id",
args.model_repo_id,
"--output_repo_id",
args.output_repo_id,
"--num_layers",
str(args.num_layers),
]
if args.subfolder:
script_args.extend(["--subfolder", args.subfolder])
if args.shrink_dims:
script_args.append("--shrink_dims")
if args.push_to_hub:
script_args.append("--push_to_hub")
job = run_uv_job(
__file__,
script_args=script_args,
flavor=args.flavor,
timeout=args.timeout,
secrets={"HF_TOKEN": args.token} if args.token else {},
)
print(f"Job submitted: {job.url}")
print(f"Job ID: {job.id}")
return job
def make_tiny_model(
model_repo_id, output_repo_id, subfolder=None, num_layers=2, shrink_dims=False, push_to_hub=False, token=None
):
from diffusers import AutoModel
config_kwargs = {}
if token:
config_kwargs["token"] = token
config = AutoModel.load_config(model_repo_id, subfolder=subfolder, **config_kwargs)
modified_keys = {}
for key, value in config.items():
if LAYER_PARAM_PATTERN.match(key) and isinstance(value, int) and value > num_layers:
modified_keys[key] = (value, num_layers)
config[key] = num_layers
if shrink_dims:
for key, value in config.items():
if not isinstance(value, int) or key.startswith("_"):
continue
for pattern, tiny_value in DIM_PARAM_PATTERNS.items():
if pattern.match(key) and value > tiny_value:
modified_keys[key] = (value, tiny_value)
config[key] = tiny_value
break
if not modified_keys:
print("WARNING: No config parameters were modified.")
print(f"Config keys: {[k for k in config if not k.startswith('_')]}")
return
print("Modified config parameters:")
for key, (old, new) in modified_keys.items():
print(f" {key}: {old} -> {new}")
model = AutoModel.from_config(config)
total_params = sum(p.numel() for p in model.parameters())
print(f"Tiny model created with {total_params:,} parameters.")
save_kwargs = {}
if token:
save_kwargs["token"] = token
if push_to_hub:
save_kwargs["repo_id"] = output_repo_id
model.save_pretrained(output_repo_id, push_to_hub=push_to_hub, **save_kwargs)
if push_to_hub:
print(f"Model pushed to https://huggingface.co/{output_repo_id}")
else:
print(f"Model saved to {output_repo_id}")
def main():
args = parse_args()
if args.launch:
launch_job(args)
else:
make_tiny_model(
model_repo_id=args.model_repo_id,
output_repo_id=args.output_repo_id,
subfolder=args.subfolder,
num_layers=args.num_layers,
shrink_dims=args.shrink_dims,
push_to_hub=args.push_to_hub,
token=args.token,
)
if __name__ == "__main__":
main()