Compare commits

...

8 Commits

Author SHA1 Message Date
Sayak Paul
052d5e6d5f Update attention_backends.md 2026-03-18 15:43:53 +05:30
kaixuanliu
8e4b5607ed skip invalid test case for helios pipeline (#13218)
* skip invalid test case for helio pipeline

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update skip reason

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
2026-03-17 20:58:35 -10:00
Junsong Chen
c6f72ad2f6 add ltx2 vae in sana-video; (#13229)
* add ltx2 vae in sana-video;

* add ltx vae in conversion script;

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* condition `vae_scale_factor_xxx` related settings on VAE types;

* make the mean/std depends on vae class;

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2026-03-17 18:09:52 -10:00
Dhruv Nair
11a3284cee [CI] Qwen Image Model Test Refactor (#13069)
* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-17 16:44:04 +05:30
Sayak Paul
16e7067647 [tests] fix llava kwargs in the hunyuan tests (#13275)
fix llava kwargs in the hunyuan tests
2026-03-17 10:11:47 +05:30
Dhruv Nair
d1b3555c29 [Modular] Fix dtype assignment when type hint is AutoModel (#13271)
* update

* update
2026-03-17 09:47:53 +05:30
Wang, Yi
9677859ebf fix parallelism case failure in xpu (#13270)
* fix parallelism case failure in xpu

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

* updated

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-17 08:52:15 +05:30
Steven Liu
ed31974c3e [docs] updates (#13248)
* fixes

* few more links

* update zh

* fix
2026-03-16 13:24:57 -07:00
19 changed files with 371 additions and 353 deletions

View File

@@ -22,6 +22,8 @@
title: Reproducibility
- local: using-diffusers/schedulers
title: Schedulers
- local: using-diffusers/guiders
title: Guiders
- local: using-diffusers/automodel
title: AutoModel
- local: using-diffusers/other-formats
@@ -110,8 +112,6 @@
title: ModularPipeline
- local: modular_diffusers/components_manager
title: ComponentsManager
- local: modular_diffusers/guiders
title: Guiders
- local: modular_diffusers/custom_blocks
title: Building Custom Blocks
- local: modular_diffusers/mellon

View File

@@ -99,7 +99,7 @@ To update guider configuration, you can run `pipe.guider = pipe.guider.new(...)`
pipe.guider = pipe.guider.new(guidance_scale=5.0)
```
Read more on Guider [here](../../modular_diffusers/guiders).
Read more on Guider [here](../../using-diffusers/guiders).

View File

@@ -30,7 +30,7 @@ HunyuanImage-2.1 comes in the following variants:
## HunyuanImage-2.1
HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../modular_diffusers/guiders.md)) and does not take a `guidance_scale` parameter at runtime. To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead.
HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../../using-diffusers/guiders)) and does not take a `guidance_scale` parameter at runtime. To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead.
```python
import torch

View File

@@ -338,7 +338,7 @@ guider = ClassifierFreeGuidance(guidance_scale=5.0)
pipeline.update_components(guider=guider)
```
See the [Guiders](./guiders) guide for more details on available guiders and how to configure them.
See the [Guiders](../using-diffusers/guiders) guide for more details on available guiders and how to configure them.
## Splitting a pipeline into stages

View File

@@ -39,7 +39,7 @@ The Modular Diffusers docs are organized as shown below.
- [ModularPipeline](./modular_pipeline) shows you how to create and convert pipeline blocks into an executable [`ModularPipeline`].
- [ComponentsManager](./components_manager) shows you how to manage and reuse components across multiple pipelines.
- [Guiders](./guiders) shows you how to use different guidance methods in the pipeline.
- [Guiders](../using-diffusers/guiders) shows you how to use different guidance methods in the pipeline.
## Mellon Integration

View File

@@ -35,7 +35,7 @@ The [`~ModelMixin.set_attention_backend`] method iterates through all the module
The example below demonstrates how to enable the `_flash_3_hub` implementation for FlashAttention-3 from the [`kernels`](https://github.com/huggingface/kernels) library, which allows you to instantly use optimized compute kernels from the Hub without requiring any setup.
> [!NOTE]
> FlashAttention-3 is not supported for non-Hopper architectures, in which case, use FlashAttention with `set_attention_backend("flash")`.
> For FlashAttention-3, at least Ampere GPUs is needed.
```py
import torch

View File

@@ -482,144 +482,6 @@ print(
) # (2880, 1, 960, 320) having a stride of 1 for the 2nd dimension proves that it works
```
## torch.jit.trace
[torch.jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) records the operations a model performs on a sample input and creates a new, optimized representation of the model based on the recorded execution path. During tracing, the model is optimized to reduce overhead from Python and dynamic control flows and operations are fused together for more efficiency. The returned executable or [ScriptFunction](https://pytorch.org/docs/stable/generated/torch.jit.ScriptFunction.html) can be compiled.
```py
import time
import torch
from diffusers import StableDiffusionPipeline
import functools
# torch disable grad
torch.set_grad_enabled(False)
# set variables
n_experiments = 2
unet_runs_per_experiment = 50
# load sample inputs
def generate_inputs():
sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
timestep = torch.rand(1, device="cuda", dtype=torch.float16) * 999
encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
return sample, timestep, encoder_hidden_states
pipeline = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
use_safetensors=True,
).to("cuda")
unet = pipeline.unet
unet.eval()
unet.to(memory_format=torch.channels_last) # use channels_last memory format
unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default
# warmup
for _ in range(3):
with torch.inference_mode():
inputs = generate_inputs()
orig_output = unet(*inputs)
# trace
print("tracing..")
unet_traced = torch.jit.trace(unet, inputs)
unet_traced.eval()
print("done tracing")
# warmup and optimize graph
for _ in range(5):
with torch.inference_mode():
inputs = generate_inputs()
orig_output = unet_traced(*inputs)
# benchmarking
with torch.inference_mode():
for _ in range(n_experiments):
torch.cuda.synchronize()
start_time = time.time()
for _ in range(unet_runs_per_experiment):
orig_output = unet_traced(*inputs)
torch.cuda.synchronize()
print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
for _ in range(n_experiments):
torch.cuda.synchronize()
start_time = time.time()
for _ in range(unet_runs_per_experiment):
orig_output = unet(*inputs)
torch.cuda.synchronize()
print(f"unet inference took {time.time() - start_time:.2f} seconds")
# save the model
unet_traced.save("unet_traced.pt")
```
Replace the pipeline's UNet with the traced version.
```py
import torch
from diffusers import StableDiffusionPipeline
from dataclasses import dataclass
@dataclass
class UNet2DConditionOutput:
sample: torch.Tensor
pipeline = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
use_safetensors=True,
).to("cuda")
# use jitted unet
unet_traced = torch.jit.load("unet_traced.pt")
# del pipeline.unet
class TracedUNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.in_channels = pipe.unet.config.in_channels
self.device = pipe.unet.device
def forward(self, latent_model_input, t, encoder_hidden_states):
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
return UNet2DConditionOutput(sample=sample)
pipeline.unet = TracedUNet()
with torch.inference_mode():
image = pipe([prompt] * 1, num_inference_steps=50).images[0]
```
## Memory-efficient attention
> [!TIP]
> Memory-efficient attention optimizes for memory usage *and* [inference speed](./fp16#scaled-dot-product-attention)!
The Transformers attention mechanism is memory-intensive, especially for long sequences, so you can try using different and more memory-efficient attention types.
By default, if PyTorch >= 2.0 is installed, [scaled dot-product attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) is used. You don't need to make any additional changes to your code.
SDPA supports [FlashAttention](https://github.com/Dao-AILab/flash-attention) and [xFormers](https://github.com/facebookresearch/xformers) as well as a native C++ PyTorch implementation. It automatically selects the most optimal implementation based on your input.
You can explicitly use xFormers with the [`~ModelMixin.enable_xformers_memory_efficient_attention`] method.
```py
# pip install xformers
import torch
from diffusers import StableDiffusionXLPipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
).to("cuda")
pipeline.enable_xformers_memory_efficient_attention()
```
Call [`~ModelMixin.disable_xformers_memory_efficient_attention`] to disable it.
```py
pipeline.disable_xformers_memory_efficient_attention()
```
Diffusers supports multiple memory-efficient attention backends (FlashAttention, xFormers, SageAttention, and more) through [`~ModelMixin.set_attention_backend`]. Refer to the [Attention backends](./attention_backends) guide to learn how to switch between them.

View File

@@ -23,7 +23,7 @@ pip install xformers
> [!TIP]
> The xFormers `pip` package requires the latest version of PyTorch. If you need to use a previous version of PyTorch, then we recommend [installing xFormers from the source](https://github.com/facebookresearch/xformers#installing-xformers).
After xFormers is installed, you can use `enable_xformers_memory_efficient_attention()` for faster inference and reduced memory consumption as shown in this [section](memory#memory-efficient-attention).
After xFormers is installed, you can use it with [`~ModelMixin.set_attention_backend`] as shown in the [Attention backends](./attention_backends) guide.
> [!WARNING]
> According to this [issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training (fine-tune or DreamBooth) in some GPUs. If you observe this problem, please install a development version as indicated in the issue comments.

View File

@@ -14,6 +14,8 @@
sections:
- local: using-diffusers/schedulers
title: Load schedulers and models
- local: using-diffusers/guiders
title: Guiders
- title: Inference
isExpanded: false
@@ -80,8 +82,6 @@
title: ModularPipeline
- local: modular_diffusers/components_manager
title: ComponentsManager
- local: modular_diffusers/guiders
title: Guiders
- title: Training
isExpanded: false

View File

@@ -12,6 +12,7 @@ from termcolor import colored
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import (
AutoencoderKLLTX2Video,
AutoencoderKLWan,
DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler,
@@ -24,7 +25,10 @@ from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
ckpt_ids = [
"Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth",
"Efficient-Large-Model/SANA-Video_2B_720p/checkpoints/SANA_Video_2B_720p_LTXVAE.pth",
]
# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
@@ -92,12 +96,22 @@ def main(args):
if args.video_size == 480:
sample_size = 30 # Wan-VAE: 8xp2 downsample factor
patch_size = (1, 2, 2)
in_channels = 16
out_channels = 16
elif args.video_size == 720:
sample_size = 22 # Wan-VAE: 32xp1 downsample factor
sample_size = 22 # DC-AE-V: 32xp1 downsample factor
patch_size = (1, 1, 1)
in_channels = 32
out_channels = 32
else:
raise ValueError(f"Video size {args.video_size} is not supported.")
if args.vae_type == "ltx2":
sample_size = 22
patch_size = (1, 1, 1)
in_channels = 128
out_channels = 128
for depth in range(layer_num):
# Transformer blocks.
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
@@ -182,8 +196,8 @@ def main(args):
# Transformer
with CTX():
transformer_kwargs = {
"in_channels": 16,
"out_channels": 16,
"in_channels": in_channels,
"out_channels": out_channels,
"num_attention_heads": 20,
"attention_head_dim": 112,
"num_layers": 20,
@@ -235,9 +249,12 @@ def main(args):
else:
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
# VAE
vae = AutoencoderKLWan.from_pretrained(
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
if args.vae_type == "ltx2":
vae_path = args.vae_path or "Lightricks/LTX-2"
vae = AutoencoderKLLTX2Video.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
else:
vae_path = args.vae_path or "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
# Text Encoder
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
@@ -314,7 +331,23 @@ if __name__ == "__main__":
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
help="Scheduler type to use.",
)
parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
parser.add_argument(
"--vae_type",
default="wan",
type=str,
choices=["wan", "ltx2"],
help="VAE type to use for saving full pipeline (ltx2 uses patchify 1x1x1).",
)
parser.add_argument(
"--vae_path",
default=None,
type=str,
required=False,
help="Optional VAE path or repo id. If not set, a default is used per VAE type.",
)
parser.add_argument(
"--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v."
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")

View File

@@ -309,16 +309,16 @@ class ComponentSpec:
f"`type_hint` is required when loading a single file model but is missing for component: {self.name}"
)
from diffusers import AutoModel
# `torch_dtype` is not an accepted parameter for tokenizers and processors.
# As a result, it gets stored in `init_kwargs`, which are written to the config
# during save. This causes JSON serialization to fail when saving the component.
if self.type_hint is not None and not issubclass(self.type_hint, torch.nn.Module):
if self.type_hint is not None and not issubclass(self.type_hint, (torch.nn.Module, AutoModel)):
kwargs.pop("torch_dtype", None)
if self.type_hint is None:
try:
from diffusers import AutoModel
component = AutoModel.from_pretrained(pretrained_model_name_or_path, **load_kwargs, **kwargs)
except Exception as e:
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
@@ -332,12 +332,6 @@ class ComponentSpec:
else getattr(self.type_hint, "from_pretrained")
)
# `torch_dtype` is not an accepted parameter for tokenizers and processors.
# As a result, it gets stored in `init_kwargs`, which are written to the config
# during save. This causes JSON serialization to fail when saving the component.
if not issubclass(self.type_hint, torch.nn.Module):
kwargs.pop("torch_dtype", None)
try:
component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs)
except Exception as e:

View File

@@ -24,7 +24,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import SanaLoraLoaderMixin
from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
@@ -194,7 +194,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
The tokenizer used to tokenize the prompt.
text_encoder ([`Gemma2PreTrainedModel`]):
Text encoder model to encode the input prompts.
vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
transformer ([`SanaVideoTransformer3DModel`]):
Conditional Transformer to denoise the input latents.
@@ -213,7 +213,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
self,
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
text_encoder: Gemma2PreTrainedModel,
vae: AutoencoderDC | AutoencoderKLWan,
vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
transformer: SanaVideoTransformer3DModel,
scheduler: DPMSolverMultistepScheduler,
):
@@ -223,8 +223,19 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
if getattr(self, "vae", None):
if isinstance(self.vae, AutoencoderKLLTX2Video):
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
self.vae_scale_factor = self.vae_scale_factor_spatial
@@ -985,14 +996,21 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError
)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
if isinstance(self.vae, AutoencoderKLLTX2Video):
latents_mean = self.vae.latents_mean
latents_std = self.vae.latents_std
z_dim = self.vae.config.latent_channels
elif isinstance(self.vae, AutoencoderKLWan):
latents_mean = torch.tensor(self.vae.config.latents_mean)
latents_std = torch.tensor(self.vae.config.latents_std)
z_dim = self.vae.config.z_dim
else:
latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
z_dim = latents.shape[1]
latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
try:
video = self.vae.decode(latents, return_dict=False)[0]

View File

@@ -26,7 +26,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...loaders import SanaLoraLoaderMixin
from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
BACKENDS_MAPPING,
@@ -184,7 +184,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
The tokenizer used to tokenize the prompt.
text_encoder ([`Gemma2PreTrainedModel`]):
Text encoder model to encode the input prompts.
vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
transformer ([`SanaVideoTransformer3DModel`]):
Conditional Transformer to denoise the input latents.
@@ -203,7 +203,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
self,
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
text_encoder: Gemma2PreTrainedModel,
vae: AutoencoderDC | AutoencoderKLWan,
vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
transformer: SanaVideoTransformer3DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
):
@@ -213,8 +213,19 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
if getattr(self, "vae", None):
if isinstance(self.vae, AutoencoderKLLTX2Video):
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
self.vae_scale_factor = self.vae_scale_factor_spatial
@@ -687,14 +698,18 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, -1, 1, 1, 1)
.to(image_latents.device, image_latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(
image_latents.device, image_latents.dtype
)
if isinstance(self.vae, AutoencoderKLLTX2Video):
_latents_mean = self.vae.latents_mean
_latents_std = self.vae.latents_std
elif isinstance(self.vae, AutoencoderKLWan):
_latents_mean = torch.tensor(self.vae.config.latents_mean)
_latents_std = torch.tensor(self.vae.config.latents_std)
else:
_latents_mean = torch.zeros(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
_latents_std = torch.ones(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
latents_mean = _latents_mean.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_std = 1.0 / _latents_std.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
image_latents = (image_latents - latents_mean) * latents_std
latents[:, :, 0:1] = image_latents.to(dtype)
@@ -1034,14 +1049,21 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError
)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
if isinstance(self.vae, AutoencoderKLLTX2Video):
latents_mean = self.vae.latents_mean
latents_std = self.vae.latents_std
z_dim = self.vae.config.latent_channels
elif isinstance(self.vae, AutoencoderKLWan):
latents_mean = torch.tensor(self.vae.config.latents_mean)
latents_std = torch.tensor(self.vae.config.latents_std)
z_dim = self.vae.config.z_dim
else:
latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
z_dim = latents.shape[1]
latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
try:
video = self.vae.decode(latents, return_dict=False)[0]

View File

@@ -26,9 +26,17 @@ from diffusers.models._modeling_parallel import ContextParallelConfig
from ...testing_utils import (
is_context_parallel,
require_torch_multi_accelerator,
torch_device,
)
# Device configuration mapping
DEVICE_CONFIG = {
"cuda": {"backend": "nccl", "module": torch.cuda},
"xpu": {"backend": "xccl", "module": torch.xpu},
}
def _find_free_port():
"""Find a free port on localhost."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -47,12 +55,17 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
# Get device configuration
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
backend = device_config["backend"]
device_module = device_config["module"]
# Initialize process group
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
# Set device for this process
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
device_module.set_device(rank)
device = torch.device(f"{torch_device}:{rank}")
# Create model
model = model_class(**init_dict)
@@ -103,10 +116,16 @@ def _custom_mesh_worker(
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
# Get device configuration
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
backend = device_config["backend"]
device_module = device_config["module"]
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
# Set device for this process
device_module.set_device(rank)
device = torch.device(f"{torch_device}:{rank}")
model = model_class(**init_dict)
model.to(device)
@@ -116,7 +135,7 @@ def _custom_mesh_worker(
# DeviceMesh must be created after init_process_group, inside each worker process.
mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
torch_device, mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
)
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
model.enable_parallelism(config=cp_config)

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,49 +12,84 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import warnings
import torch
from diffusers import QwenImageTransformer2DModel
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return QwenImageTransformer2DModel
@property
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
def output_shape(self) -> tuple[int, int]:
return (16, 16)
@property
def output_shape(self):
def input_shape(self) -> tuple[int, int]:
return (16, 16)
def prepare_dummy_input(self, height=4, width=4):
@property
def model_split_percents(self) -> list:
return [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
return {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 4,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 7
height = width = 4
sequence_length = 8
vae_scale_factor = 4
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
@@ -70,89 +104,57 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
"img_shapes": img_shapes,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"joint_attention_dim": 16,
"guidance_embeds": False,
"axes_dims_rope": (8, 4, 4),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
def test_infers_text_seq_len_from_mask(self):
"""Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
# Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid
encoder_hidden_states_mask[:, 2:] = 0
rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask
)
# Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
self.assertIsInstance(rope_text_seq_len, int)
assert isinstance(rope_text_seq_len, int)
assert isinstance(per_sample_len, torch.Tensor)
assert int(per_sample_len.max().item()) == 2
assert normalized_mask.dtype == torch.bool
assert normalized_mask.sum().item() == 2
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
# Verify per_sample_len is computed correctly (max valid position + 1 = 2)
self.assertIsInstance(per_sample_len, torch.Tensor)
self.assertEqual(int(per_sample_len.max().item()), 2)
# Verify mask is normalized to bool dtype
self.assertTrue(normalized_mask.dtype == torch.bool)
self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values
# Verify rope_text_seq_len is at least the sequence length
self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
# Test 2: Verify model runs successfully with inferred values
inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad():
output = model(**inputs)
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
# Test 3: Different mask pattern (padding at beginning)
encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding
encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid
encoder_hidden_states_mask2[:, :3] = 0
encoder_hidden_states_mask2[:, 3:] = 1
rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask2
)
# Max valid position is 6 (last token), so per_sample_len should be 7
self.assertEqual(int(per_sample_len2.max().item()), 7)
self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values
assert int(per_sample_len2.max().item()) == 8
assert normalized_mask2.sum().item() == 5
# Test 4: No mask provided (None case)
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], None
)
self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(rope_text_seq_len_none, int)
self.assertIsNone(per_sample_len_none)
self.assertIsNone(normalized_mask_none)
assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1]
assert isinstance(rope_text_seq_len_none, int)
assert per_sample_len_none is None
assert normalized_mask_none is None
def test_non_contiguous_attention_mask(self):
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
# Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
# Pattern: [True, False, True, False, True, False, False]
encoder_hidden_states_mask[:, 1] = 0
encoder_hidden_states_mask[:, 3] = 0
encoder_hidden_states_mask[:, 5:] = 0
@@ -160,95 +162,85 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], encoder_hidden_states_mask
)
self.assertEqual(int(per_sample_len.max().item()), 5)
self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
self.assertIsInstance(inferred_rope_len, int)
self.assertTrue(normalized_mask.dtype == torch.bool)
assert int(per_sample_len.max().item()) == 5
assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1]
assert isinstance(inferred_rope_len, int)
assert normalized_mask.dtype == torch.bool
inputs["encoder_hidden_states_mask"] = normalized_mask
with torch.no_grad():
output = model(**inputs)
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
def test_txt_seq_lens_deprecation(self):
"""Test that passing txt_seq_lens raises a deprecation warning."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
# Prepare inputs with txt_seq_lens (deprecated parameter)
txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]
# Remove encoder_hidden_states_mask to use the deprecated path
inputs_with_deprecated = inputs.copy()
inputs_with_deprecated.pop("encoder_hidden_states_mask")
inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens
# Test that deprecation warning is raised
with self.assertWarns(FutureWarning) as warning_context:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with torch.no_grad():
output = model(**inputs_with_deprecated)
# Verify the warning message mentions the deprecation
warning_message = str(warning_context.warning)
self.assertIn("txt_seq_lens", warning_message)
self.assertIn("deprecated", warning_message)
self.assertIn("encoder_hidden_states_mask", warning_message)
future_warnings = [x for x in w if issubclass(x.category, FutureWarning)]
assert len(future_warnings) > 0, "Expected FutureWarning to be raised"
# Verify the model still works correctly despite the deprecation
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
warning_message = str(future_warnings[0].message)
assert "txt_seq_lens" in warning_message
assert "deprecated" in warning_message
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
def test_layered_model_with_mask(self):
"""Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
# Create layered model config
init_dict = {
"patch_size": 2,
"in_channels": 16,
"out_channels": 4,
"num_layers": 2,
"attention_head_dim": 16,
"num_attention_heads": 3,
"num_attention_heads": 4,
"joint_attention_dim": 16,
"axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16)
"use_layer3d_rope": True, # Enable layered RoPE
"use_additional_t_cond": True, # Enable additional time conditioning
"axes_dims_rope": (8, 4, 4),
"use_layer3d_rope": True,
"use_additional_t_cond": True,
}
model = self.model_class(**init_dict).to(torch_device)
# Verify the model uses QwenEmbedLayer3DRope
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope
self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
assert isinstance(model.pos_embed, QwenEmbedLayer3DRope)
# Test single generation with layered structure
batch_size = 1
text_seq_len = 7
text_seq_len = 8
img_h, img_w = 4, 4
layers = 4
# For layered model: (layers + 1) because we have N layers + 1 combined image
hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)
# Create mask with some padding
encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens
encoder_hidden_states_mask[0, 5:] = 0
timestep = torch.tensor([1.0]).to(torch_device)
# additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding)
addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)
# Layer structure: 4 layers + 1 condition image
img_shapes = [
[
(1, img_h, img_w), # layer 0
(1, img_h, img_w), # layer 1
(1, img_h, img_w), # layer 2
(1, img_h, img_w), # layer 3
(1, img_h, img_w), # condition image (last one gets special treatment)
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w),
(1, img_h, img_w),
]
]
@@ -262,37 +254,113 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
additional_t_cond=addition_t_cond,
)
self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
assert output.sample.shape[1] == hidden_states.shape[1]
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = QwenImageTransformer2DModel
class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for QwenImage Transformer."""
def prepare_init_args_and_inputs_for_common(self):
return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
def prepare_dummy_input(self, height, width):
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for QwenImage Transformer."""
def test_torch_compile_recompilation_and_graph_break(self):
super().test_torch_compile_recompilation_and_graph_break()
def test_gradient_checkpointing_is_applied(self):
expected_set = {"QwenImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for QwenImage Transformer."""
class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for QwenImage Transformer."""
class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for QwenImage Transformer."""
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for QwenImage Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 8
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for QwenImage Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
sequence_length = 8
vae_scale_factor = 4
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
orig_height = height * 2 * vae_scale_factor
orig_width = width * 2 * vae_scale_factor
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_mask": encoder_hidden_states_mask,
"timestep": timestep,
"img_shapes": img_shapes,
}
def test_torch_compile_with_and_without_mask(self):
"""Test that torch.compile works with both None mask and padding mask."""
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile(mode="default", fullgraph=True)
# Test 1: Run with None mask (no padding, all tokens are valid)
inputs_no_mask = inputs.copy()
inputs_no_mask["encoder_hidden_states_mask"] = None
# First run to allow compilation
with torch.no_grad():
output_no_mask = model(**inputs_no_mask)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
@@ -300,19 +368,15 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
):
output_no_mask_2 = model(**inputs_no_mask)
self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1]
# Test 2: Run with all-ones mask (should behave like None)
inputs_all_ones = inputs.copy()
# Keep the all-ones mask
self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
assert inputs_all_ones["encoder_hidden_states_mask"].all().item()
# First run to allow compilation
with torch.no_grad():
output_all_ones = model(**inputs_all_ones)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
@@ -320,21 +384,18 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
):
output_all_ones_2 = model(**inputs_all_ones)
self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1]
# Test 3: Run with actual padding mask (has zeros)
inputs_with_padding = inputs.copy()
mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
mask_with_padding[:, 4:] = 0 # Last 3 tokens are padding
mask_with_padding[:, 4:] = 0
inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding
# First run to allow compilation
with torch.no_grad():
output_with_padding = model(**inputs_with_padding)
# Second run to verify no recompilation
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
@@ -342,8 +403,15 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
):
output_with_padding_2 = model(**inputs_with_padding)
self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1]
assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1]
# Verify that outputs are different (mask should affect results)
self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)
class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for QwenImage Transformer."""
class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for QwenImage Transformer."""

View File

@@ -139,9 +139,9 @@ class HeliosPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
# Override to set a more lenient max diff threshold.
@unittest.skip("Helios uses a lot of mixed precision internally, which is not suitable for this test case")
def test_save_load_float16(self):
super().test_save_load_float16(expected_max_diff=0.03)
pass
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):

View File

@@ -139,7 +139,9 @@ class HunyuanVideoImageToVideoPipelineFastTests(
num_hidden_layers=2,
image_size=224,
)
llava_text_encoder_config = LlavaConfig(vision_config, text_config, pad_token_id=100, image_token_index=101)
llava_text_encoder_config = LlavaConfig(
vision_config=vision_config, text_config=text_config, pad_token_id=100, image_token_index=101
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,