Compare commits

..

23 Commits

Author SHA1 Message Date
Sayak Paul
87e7b82b08 Merge branch 'main' into unet-model-tests-refactor 2026-04-12 12:23:51 +05:30
HsiaWinter
dc8d903217 Add ernie image (#13432)
* Add ERNIE-Image

* Update doc

* Update doc

* Change from Custom-Attention to Diffusers Style Attention

* Change from Custom-Attention to Diffusers Style Attention

* 兼容SGLang

* 优化PE模块的加载与offload策略

* 更新Doc文件与config配置相关内容

* Fix官方反馈的内容

* 根据官方建议优化代码

* Update code

* update

* update

* Apply style fixes

* update

* update

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-04-10 17:06:31 -10:00
Akshan Krithick
5a9a941a89 [modular] Add LTX Video modular pipeline (#13378)
* Add modular pipeline support for LTX Video

* Fix guidance_scale passthrough to guider

* Add LTX modular pipeline tests

* Add LTX image-to-video modular pipeline

* Fix i2v VAE dtype mismatch

* Add cache_context to denoiser for CFG parity

* Address review feedback

* Generate auto docstrings for LTX assembled blocks

* Fix ruff lint and format issues

* use InputParam/OutputParam templates and ruff check

* address all review

* Lift LTXVaeEncoderStep out of I2V core denoise into its own auto block

* unused declarations

* Fix check_copies: sync retrieve_timesteps and drop unsupported Copied from tags

* Address review: remove unused code, add LTXAutoBlocks, refactor I2V latents flow

* removed LTXBlocks,LTXImage2VideoBlocks

* Update test to use LTXAutoBlocks

* workflow map and auto docstring

* Add LTXVideoPachifier, workflow map

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2026-04-10 16:53:32 -10:00
Akshan Krithick
87beae7771 Fix HunyuanVideo 1.5 I2V by preprocessing image at pixel resolution i… (#13440)
Fix HunyuanVideo 1.5 I2V by preprocessing image at pixel resolution instead of latent resolution
2026-04-10 09:54:36 -10:00
Xyc2016
251676dfda Fix grammar in LoRA documentation (#13423)
Fix grammar in LoRA documentation (LoRA's → LoRAs, trigger it → trigger them)
2026-04-10 09:18:30 -07:00
Sayak Paul
896fec351b [tests] tighten dependency testing. (#13332)
* tighten dependency testing.

* invoke dependency testing temporarily.

* f
2026-04-10 18:12:12 +05:30
sayakpaul
bfebba420a resolve conflicts 2026-03-27 13:42:13 +05:30
Sayak Paul
17257b5d68 Merge branch 'main' into unet-model-tests-refactor 2026-03-23 15:27:34 +05:30
Sayak Paul
ecfd3b4f99 Merge branch 'main' into unet-model-tests-refactor 2026-02-16 16:37:07 +05:30
sayakpaul
5f8303fe3c remove test suites that are passed. 2026-02-16 16:36:06 +05:30
sayakpaul
99de4ceab8 [tests] refactor test_models_unet_spatiotemporal.py to use modular testing mixins
Refactored the spatiotemporal UNet test file to follow the modern modular testing
pattern with BaseModelTesterConfig and focused test classes:

- UNetSpatioTemporalTesterConfig: Base configuration with model setup
- TestUNetSpatioTemporal: Core model tests (ModelTesterMixin, UNetTesterMixin)
- TestUNetSpatioTemporalAttention: Attention-related tests (AttentionTesterMixin)
- TestUNetSpatioTemporalMemory: Memory/offloading tests (MemoryTesterMixin)
- TestUNetSpatioTemporalTraining: Training tests (TrainingTesterMixin)
- TestUNetSpatioTemporalLoRA: LoRA adapter tests (LoraTesterMixin)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 16:13:42 +05:30
sayakpaul
c6e6992cdd [tests] refactor test_models_unet_controlnetxs.py to use modular testing mixins
Refactor UNetControlNetXSModel tests to follow the modern testing
pattern with separate classes for core, memory, training, and LoRA.
Specialized tests (from_unet, freeze_unet, forward_no_control,
time_embedding_mixing) remain in the core test class.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 16:11:00 +05:30
sayakpaul
ecbaed793d [tests] refactor test_models_unet_3d_condition.py to use modular testing mixins
Refactor UNet3DConditionModel tests to follow the modern testing pattern
with separate classes for core, attention, memory, training, and LoRA.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 16:09:42 +05:30
sayakpaul
0411da7739 [tests] refactor test_models_unet_2d.py to use modular testing mixins
Refactor UNet2D model tests (standard, LDM, NCSN++) to follow the
modern testing pattern. Each variant gets its own config class and
dedicated test classes organized by concern (core, memory, training,
LoRA, hub loading).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 16:08:35 +05:30
sayakpaul
ffb254a273 [tests] refactor test_models_unet_1d.py to use modular testing mixins
Refactor UNet1D model tests to follow the modern testing pattern using
BaseModelTesterConfig and focused mixin classes (ModelTesterMixin,
MemoryTesterMixin, TrainingTesterMixin, LoraTesterMixin).

Both UNet1D standard and RL variants now have separate config classes
and dedicated test classes organized by concern (core, memory, training,
LoRA, hub loading).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 16:05:32 +05:30
sayakpaul
ea08148bbd recompile limit 2026-02-16 15:41:57 +05:30
sayakpaul
3a610814a3 Merge branch 'main' into unet-model-tests-refactor 2026-02-16 15:41:00 +05:30
sayakpaul
ca4a7b0649 up 2026-02-16 15:40:24 +05:30
sayakpaul
3371560f1d Revert "fix"
This reverts commit 46d44b73d8.
2026-02-16 13:34:24 +05:30
sayakpaul
46d44b73d8 fix 2026-02-16 13:30:54 +05:30
sayakpaul
2b67fb65ef up 2026-02-16 13:10:04 +05:30
sayakpaul
0e42a3ff93 fix tests 2026-02-16 11:59:33 +05:30
sayakpaul
14439ab793 refactor unet2d condition model tests. 2026-02-16 10:08:41 +05:30
41 changed files with 4126 additions and 890 deletions

View File

@@ -6,6 +6,7 @@ on:
- main
paths:
- "src/diffusers/**.py"
- "tests/**.py"
push:
branches:
- main

View File

@@ -6,6 +6,7 @@ on:
- main
paths:
- "src/diffusers/**.py"
- "tests/**.py"
push:
branches:
- main
@@ -26,7 +27,7 @@ jobs:
- name: Install dependencies
run: |
pip install -e .
pip install torch torchvision torchaudio pytest
pip install torch pytest
- name: Check for soft dependencies
run: |
pytest tests/others/test_dependencies.py

View File

@@ -350,6 +350,8 @@
title: DiTTransformer2DModel
- local: api/models/easyanimate_transformer3d
title: EasyAnimateTransformer3DModel
- local: api/models/ernie_image_transformer2d
title: ErnieImageTransformer2DModel
- local: api/models/flux2_transformer
title: Flux2Transformer2DModel
- local: api/models/flux_transformer
@@ -534,6 +536,8 @@
title: DiT
- local: api/pipelines/easyanimate
title: EasyAnimate
- local: api/pipelines/ernie_image
title: ERNIE-Image
- local: api/pipelines/flux
title: Flux
- local: api/pipelines/flux2

View File

@@ -0,0 +1,21 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# ErnieImageTransformer2DModel
A Transformer model for image-like data from [ERNIE-Image](https://huggingface.co/baidu/ERNIE-Image).
A Transformer model for image-like data from [ERNIE-Image-Turbo](https://huggingface.co/baidu/ERNIE-Image-Turbo).
## ErnieImageTransformer2DModel
[[autodoc]] ErnieImageTransformer2DModel

View File

@@ -0,0 +1,86 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Ernie-Image
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
[ERNIE-Image] is a powerful and highly efficient image generation model with 8B parameters. Currently there's only two models to be released:
|Model|Hugging Face|
|---|---|
|ERNIE-Image|https://huggingface.co/baidu/ERNIE-Image|
|ERNIE-Image-Turbo|https://huggingface.co/baidu/ERNIE-Image-Turbo|
## ERNIE-Image
ERNIE-Image is designed with a relatively compact architecture and solid instruction-following capability, emphasizing parameter efficiency. Based on an 8B DiT backbone, it provides performance that is comparable in some scenarios to larger (20B+) models, while maintaining reasonable parameter efficiency. It offers a relatively stable level of performance in instruction understanding and execution, text generation (e.g., English / Chinese / Japanese), and overall stability.
## ERNIE-Image-Turbo
ERNIE-Image-Turbo is a distilled variant of ERNIE-Image, requiring only 8 NFEs (Number of Function Evaluations) and offering a more efficient alternative with relatively comparable performance to the full model in certain cases.
## ErnieImagePipeline
Use [ErnieImagePipeline] to generate images from text prompts. The pipeline supports Prompt Enhancer (PE) by default, which enhances the users raw prompt to improve output quality, though it may reduce instruction-following accuracy.
We provide a pretrained 3B-parameter PE model; however, using larger language models (e.g., Gemini or ChatGPT) for prompt enhancement may yield better results. The system prompt template is available at: https://huggingface.co/baidu/ERNIE-Image/blob/main/pe/chat_template.jinja.
If you prefer not to use PE, set use_pe=False.
```python
import torch
from diffusers import ErnieImagePipeline
from diffusers.utils import load_image
pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# If you are running low on GPU VRAM, you can enable offloading
pipe.enable_model_cpu_offload()
prompt = "一只黑白相间的中华田园犬"
images = pipe(
prompt=prompt,
height=1024,
width=1024,
num_inference_steps=50,
guidance_scale=4.0,
generator=torch.Generator("cuda").manual_seed(42),
use_pe=True,
).images
images[0].save("ernie-image-output.png")
```
```python
import torch
from diffusers import ErnieImagePipeline
from diffusers.utils import load_image
pipe = ErnieImagePipeline.from_pretrained("baidu/ERNIE-Image-Turbo", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# If you are running low on GPU VRAM, you can enable offloading
pipe.enable_model_cpu_offload()
prompt = "一只黑白相间的中华田园犬"
images = pipe(
prompt=prompt,
height=1024,
width=1024,
num_inference_steps=8,
guidance_scale=1.0,
generator=torch.Generator("cuda").manual_seed(42),
use_pe=True,
).images
images[0].save("ernie-image-turbo-output.png")
```

View File

@@ -101,9 +101,9 @@ export_to_video(video, "output.mp4", fps=16)
## LoRA
Adapters insert a small number of trainable parameters to the original base model. Only the inserted parameters are fine-tuned while the rest of the model weights remain frozen. This makes it fast and cheap to fine-tune a model on a new style. Among adapters, [LoRA's](./tutorials/using_peft_for_inference) are the most popular.
Adapters insert a small number of trainable parameters to the original base model. Only the inserted parameters are fine-tuned while the rest of the model weights remain frozen. This makes it fast and cheap to fine-tune a model on a new style. Among adapters, [LoRAs](./tutorials/using_peft_for_inference) are the most popular.
Add a LoRA to a pipeline with the [`~loaders.QwenImageLoraLoaderMixin.load_lora_weights`] method. Some LoRA's require a special word to trigger it, such as `Realism`, in the example below. Check a LoRA's model card to see if it requires a trigger word.
Add a LoRA to a pipeline with the [`~loaders.QwenImageLoraLoaderMixin.load_lora_weights`] method. Some LoRAs require a special word to trigger them, such as `Realism`, in the example below. Check a LoRA's model card to see if it requires a trigger word.
```py
import torch

View File

@@ -235,6 +235,7 @@ else:
"CosmosTransformer3DModel",
"DiTTransformer2DModel",
"EasyAnimateTransformer3DModel",
"ErnieImageTransformer2DModel",
"Flux2Transformer2DModel",
"FluxControlNetModel",
"FluxMultiControlNetModel",
@@ -455,6 +456,8 @@ else:
"HeliosPyramidDistilledAutoBlocks",
"HeliosPyramidDistilledModularPipeline",
"HeliosPyramidModularPipeline",
"LTXAutoBlocks",
"LTXModularPipeline",
"QwenImageAutoBlocks",
"QwenImageEditAutoBlocks",
"QwenImageEditModularPipeline",
@@ -525,6 +528,7 @@ else:
"EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline",
"EasyAnimatePipeline",
"ErnieImagePipeline",
"Flux2KleinKVPipeline",
"Flux2KleinPipeline",
"Flux2Pipeline",
@@ -1035,6 +1039,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CosmosTransformer3DModel,
DiTTransformer2DModel,
EasyAnimateTransformer3DModel,
ErnieImageTransformer2DModel,
Flux2Transformer2DModel,
FluxControlNetModel,
FluxMultiControlNetModel,
@@ -1234,6 +1239,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HeliosPyramidDistilledAutoBlocks,
HeliosPyramidDistilledModularPipeline,
HeliosPyramidModularPipeline,
LTXAutoBlocks,
LTXModularPipeline,
QwenImageAutoBlocks,
QwenImageEditAutoBlocks,
QwenImageEditModularPipeline,
@@ -1300,6 +1307,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
EasyAnimatePipeline,
ErnieImagePipeline,
Flux2KleinKVPipeline,
Flux2KleinPipeline,
Flux2Pipeline,

View File

@@ -101,6 +101,7 @@ if is_torch_available():
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_ernie_image"] = ["ErnieImageTransformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
@@ -219,6 +220,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
DiTTransformer2DModel,
DualTransformer2DModel,
EasyAnimateTransformer3DModel,
ErnieImageTransformer2DModel,
Flux2Transformer2DModel,
FluxTransformer2DModel,
GlmImageTransformer2DModel,

View File

@@ -540,7 +540,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
)
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_version(">=", "0.12.3"):
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
raise RuntimeError(
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
)

View File

@@ -25,6 +25,7 @@ if is_torch_available():
from .transformer_cogview4 import CogView4Transformer2DModel
from .transformer_cosmos import CosmosTransformer3DModel
from .transformer_easyanimate import EasyAnimateTransformer3DModel
from .transformer_ernie_image import ErnieImageTransformer2DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_flux2 import Flux2Transformer2DModel
from .transformer_glm_image import GlmImageTransformer2DModel

View File

@@ -0,0 +1,430 @@
# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Ernie-Image Transformer2DModel for HuggingFace Diffusers.
"""
import inspect
from dataclasses import dataclass
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, logging
from ..attention import AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class ErnieImageTransformer2DModelOutput(BaseOutput):
sample: torch.Tensor
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
return out.float()
class ErnieImageEmbedND3(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: Tuple[int, int, int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = list(axes_dim)
def forward(self, ids: torch.Tensor) -> torch.Tensor:
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
emb = emb.unsqueeze(2) # [B, S, 1, head_dim//2]
return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim]
class ErnieImagePatchEmbedDynamic(nn.Module):
def __init__(self, in_channels: int, embed_dim: int, patch_size: int):
super().__init__()
self.patch_size = patch_size
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
batch_size, dim, height, width = x.shape
return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous()
class ErnieImageSingleStreamAttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"ErnieImageSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
freqs_cis: torch.Tensor | None = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
# Apply Norms
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE: same rotate_half logic as Megatron _apply_rotary_pos_emb_bshd (rotary_interleaved=False)
# x_in: [B, S, heads, head_dim], freqs_cis: [B, S, 1, head_dim] with angles [θ0,θ0,θ1,θ1,...]
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
rot_dim = freqs_cis.shape[-1]
x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
cos_ = torch.cos(freqs_cis).to(x.dtype)
sin_ = torch.sin(freqs_cis).to(x.dtype)
# Non-interleaved rotate_half: [-x2, x1]
x1, x2 = x.chunk(2, dim=-1)
x_rotated = torch.cat((-x2, x1), dim=-1)
return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
if freqs_cis is not None:
query = apply_rotary_emb(query, freqs_cis)
key = apply_rotary_emb(key, freqs_cis)
# Cast to correct dtype
dtype = query.dtype
query, key = query.to(dtype), key.to(dtype)
# From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
if attention_mask is not None and attention_mask.ndim == 2:
attention_mask = attention_mask[:, None, None, :]
# Compute joint attention
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
# Reshape back
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(dtype)
output = attn.to_out[0](hidden_states)
return output
class ErnieImageAttention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = ErnieImageSingleStreamAttnProcessor
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
qk_norm: str = "rms_norm",
added_proj_bias: bool | None = True,
out_bias: bool = True,
eps: float = 1e-5,
out_dim: int = None,
elementwise_affine: bool = True,
processor=None,
):
super().__init__()
self.head_dim = dim_head
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.query_dim = query_dim
self.out_dim = out_dim if out_dim is not None else query_dim
self.heads = out_dim // dim_head if out_dim is not None else heads
self.use_bias = bias
self.dropout = dropout
self.added_proj_bias = added_proj_bias
self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
# QK Norm
if qk_norm == "layer_norm":
self.norm_q = torch.nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
elif qk_norm == "rms_norm":
self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
else:
raise ValueError(
f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
)
self.to_out = torch.nn.ModuleList([])
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
image_rotary_emb: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
if len(unused_kwargs) > 0:
logger.warning(
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
)
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
return self.processor(self, hidden_states, attention_mask, image_rotary_emb, **kwargs)
class ErnieImageFeedForward(nn.Module):
def __init__(self, hidden_size: int, ffn_hidden_size: int):
super().__init__()
# Separate gate and up projections (matches converted weights)
self.gate_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
self.up_proj = nn.Linear(hidden_size, ffn_hidden_size, bias=False)
self.linear_fc2 = nn.Linear(ffn_hidden_size, hidden_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x)))
class ErnieImageSharedAdaLNBlock(nn.Module):
def __init__(
self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, qk_layernorm: bool = True
):
super().__init__()
self.adaLN_sa_ln = RMSNorm(hidden_size, eps=eps)
self.self_attention = ErnieImageAttention(
query_dim=hidden_size,
dim_head=hidden_size // num_heads,
heads=num_heads,
qk_norm="rms_norm" if qk_layernorm else None,
eps=eps,
bias=False,
out_bias=False,
processor=ErnieImageSingleStreamAttnProcessor(),
)
self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps)
self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size)
def forward(
self,
x,
rotary_pos_emb,
temb: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None = None,
):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb
residual = x
x = self.adaLN_sa_ln(x)
x = (x.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
x_bsh = x.permute(1, 0, 2) # [S, B, H] → [B, S, H] for diffusers Attention (batch-first)
attn_out = self.self_attention(x_bsh, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
attn_out = attn_out.permute(1, 0, 2) # [B, S, H] → [S, B, H]
x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
residual = x
x = self.adaLN_mlp_ln(x)
x = (x.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
return residual + (gate_mlp.float() * self.mlp(x).float()).to(x.dtype)
class ErnieImageAdaLNContinuous(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=eps)
self.linear = nn.Linear(hidden_size, hidden_size * 2)
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
scale, shift = self.linear(conditioning).chunk(2, dim=-1)
x = self.norm(x)
# Broadcast conditioning to sequence dimension
x = x * (1 + scale.unsqueeze(0)) + shift.unsqueeze(0)
return x
class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
hidden_size: int = 3072,
num_attention_heads: int = 24,
num_layers: int = 24,
ffn_hidden_size: int = 8192,
in_channels: int = 128,
out_channels: int = 128,
patch_size: int = 1,
text_in_dim: int = 2560,
rope_theta: int = 256,
rope_axes_dim: Tuple[int, int, int] = (32, 48, 48),
eps: float = 1e-6,
qk_layernorm: bool = True,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_attention_heads
self.head_dim = hidden_size // num_attention_heads
self.num_layers = num_layers
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = out_channels
self.text_in_dim = text_in_dim
self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size)
self.text_proj = nn.Linear(text_in_dim, hidden_size, bias=False) if text_in_dim != hidden_size else None
self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False, downscale_freq_shift=0)
self.time_embedding = TimestepEmbedding(hidden_size, hidden_size)
self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size))
nn.init.zeros_(self.adaLN_modulation[-1].weight)
nn.init.zeros_(self.adaLN_modulation[-1].bias)
self.layers = nn.ModuleList(
[
ErnieImageSharedAdaLNBlock(
hidden_size, num_attention_heads, ffn_hidden_size, eps, qk_layernorm=qk_layernorm
)
for _ in range(num_layers)
]
)
self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps)
self.final_linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels)
nn.init.zeros_(self.final_linear.weight)
nn.init.zeros_(self.final_linear.bias)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
# encoder_hidden_states: List[torch.Tensor],
text_bth: torch.Tensor,
text_lens: torch.Tensor,
return_dict: bool = True,
):
device, dtype = hidden_states.device, hidden_states.dtype
B, C, H, W = hidden_states.shape
p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size
N_img = Hp * Wp
img_sbh = self.x_embedder(hidden_states).transpose(0, 1).contiguous()
# text_bth, text_lens = self._pad_text(encoder_hidden_states, device, dtype)
if self.text_proj is not None and text_bth.numel() > 0:
text_bth = self.text_proj(text_bth)
Tmax = text_bth.shape[1]
text_sbh = text_bth.transpose(0, 1).contiguous()
x = torch.cat([img_sbh, text_sbh], dim=0)
S = x.shape[0]
# Position IDs
text_ids = (
torch.cat(
[
torch.arange(Tmax, device=device, dtype=torch.float32).view(1, Tmax, 1).expand(B, -1, -1),
torch.zeros((B, Tmax, 2), device=device),
],
dim=-1,
)
if Tmax > 0
else torch.zeros((B, 0, 3), device=device)
)
grid_yx = torch.stack(
torch.meshgrid(
torch.arange(Hp, device=device, dtype=torch.float32),
torch.arange(Wp, device=device, dtype=torch.float32),
indexing="ij",
),
dim=-1,
).reshape(-1, 2)
image_ids = torch.cat(
[text_lens.float().view(B, 1, 1).expand(-1, N_img, -1), grid_yx.view(1, N_img, 2).expand(B, -1, -1)],
dim=-1,
)
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))
# Attention mask: True = valid (attend), False = padding (mask out), matches sdpa bool convention
valid_text = (
torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1)
if Tmax > 0
else torch.zeros((B, 0), device=device, dtype=torch.bool)
)
attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[
:, None, None, :
]
# AdaLN
sample = self.time_proj(timestep.to(dtype))
sample = sample.to(self.time_embedding.linear_1.weight.dtype)
c = self.time_embedding(sample)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)
]
for layer in self.layers:
temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp]
if torch.is_grad_enabled() and self.gradient_checkpointing:
x = self._gradient_checkpointing_func(
layer,
x,
rotary_pos_emb,
temb,
attention_mask,
)
else:
x = layer(x, rotary_pos_emb, temb, attention_mask)
x = self.final_norm(x, c).type_as(x)
patches = self.final_linear(x)[:N_img].transpose(0, 1).contiguous()
output = (
patches.view(B, Hp, Wp, p, p, self.out_channels)
.permute(0, 5, 1, 3, 2, 4)
.contiguous()
.view(B, self.out_channels, H, W)
)
return ErnieImageTransformer2DModelOutput(sample=output) if return_dict else (output,)

View File

@@ -88,6 +88,10 @@ else:
"QwenImageLayeredModularPipeline",
"QwenImageLayeredAutoBlocks",
]
_import_structure["ltx"] = [
"LTXAutoBlocks",
"LTXModularPipeline",
]
_import_structure["z_image"] = [
"ZImageAutoBlocks",
"ZImageModularPipeline",
@@ -119,6 +123,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HeliosPyramidDistilledModularPipeline,
HeliosPyramidModularPipeline,
)
from .ltx import LTXAutoBlocks, LTXModularPipeline
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,

View File

@@ -0,0 +1,47 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modular_blocks_ltx"] = ["LTXAutoBlocks", "LTXBlocks", "LTXImage2VideoBlocks"]
_import_structure["modular_pipeline"] = ["LTXModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_blocks_ltx import LTXAutoBlocks, LTXBlocks, LTXImage2VideoBlocks
from .modular_pipeline import LTXModularPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,392 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import numpy as np
import torch
from ...configuration_utils import FrozenDict
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import LTXModularPipeline, LTXVideoPachifier
logger = logging.get_logger(__name__)
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: int | None = None,
device: str | torch.device | None = None,
timesteps: list[int] | None = None,
sigmas: list[float] | None = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`list[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`list[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class LTXTextInputStep(ModularPipelineBlocks):
model_name = "ltx"
@property
def description(self) -> str:
return (
"Input processing step that:\n"
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
" 2. Adjusts input tensor shapes based on `batch_size` and `num_videos_per_prompt`"
)
@property
def inputs(self) -> list[InputParam]:
return [
InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"),
InputParam.template("prompt_embeds", required=True),
InputParam.template("prompt_embeds_mask", name="prompt_attention_mask"),
InputParam.template("negative_prompt_embeds"),
InputParam.template("negative_prompt_embeds_mask", name="negative_prompt_attention_mask"),
]
@property
def intermediate_outputs(self) -> list[OutputParam]:
return [
OutputParam("batch_size", type_hint=int),
OutputParam("dtype", type_hint=torch.dtype),
]
@torch.no_grad()
def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.batch_size = block_state.prompt_embeds.shape[0]
block_state.dtype = block_state.prompt_embeds.dtype
num_videos = block_state.num_videos_per_prompt
# Repeat prompt_embeds for num_videos_per_prompt
_, seq_len, _ = block_state.prompt_embeds.shape
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, num_videos, 1)
block_state.prompt_embeds = block_state.prompt_embeds.view(block_state.batch_size * num_videos, seq_len, -1)
if block_state.prompt_attention_mask is not None:
block_state.prompt_attention_mask = block_state.prompt_attention_mask.repeat(num_videos, 1)
if block_state.negative_prompt_embeds is not None:
_, seq_len, _ = block_state.negative_prompt_embeds.shape
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, num_videos, 1)
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
block_state.batch_size * num_videos, seq_len, -1
)
if block_state.negative_prompt_attention_mask is not None:
block_state.negative_prompt_attention_mask = block_state.negative_prompt_attention_mask.repeat(
num_videos, 1
)
self.set_block_state(state, block_state)
return components, state
class LTXSetTimestepsStep(ModularPipelineBlocks):
model_name = "ltx"
@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
]
@property
def description(self) -> str:
return "Step that sets the scheduler's timesteps for inference"
@property
def inputs(self) -> list[InputParam]:
return [
InputParam.template("num_inference_steps"),
InputParam.template("timesteps"),
InputParam.template("sigmas"),
InputParam.template("height", default=512),
InputParam.template("width", default=704),
InputParam("num_frames", type_hint=int, default=161),
InputParam("frame_rate", type_hint=int, default=25),
]
@property
def intermediate_outputs(self) -> list[OutputParam]:
return [
OutputParam("timesteps", type_hint=torch.Tensor),
OutputParam("num_inference_steps", type_hint=int),
OutputParam("rope_interpolation_scale", type_hint=tuple),
]
@torch.no_grad()
def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
height = block_state.height
width = block_state.width
num_frames = block_state.num_frames
frame_rate = block_state.frame_rate
latent_num_frames = (num_frames - 1) // components.vae_temporal_compression_ratio + 1
latent_height = height // components.vae_spatial_compression_ratio
latent_width = width // components.vae_spatial_compression_ratio
video_sequence_length = latent_num_frames * latent_height * latent_width
custom_timesteps = block_state.timesteps
sigmas = block_state.sigmas
if custom_timesteps is not None:
# User provided custom timesteps, don't compute sigmas
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
components.scheduler,
block_state.num_inference_steps,
device,
custom_timesteps,
)
else:
if sigmas is None:
sigmas = np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
mu = calculate_shift(
video_sequence_length,
components.scheduler.config.get("base_image_seq_len", 256),
components.scheduler.config.get("max_image_seq_len", 4096),
components.scheduler.config.get("base_shift", 0.5),
components.scheduler.config.get("max_shift", 1.15),
)
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
components.scheduler,
block_state.num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
block_state.rope_interpolation_scale = (
components.vae_temporal_compression_ratio / frame_rate,
components.vae_spatial_compression_ratio,
components.vae_spatial_compression_ratio,
)
self.set_block_state(state, block_state)
return components, state
class LTXPrepareLatentsStep(ModularPipelineBlocks):
model_name = "ltx"
@property
def description(self) -> str:
return "Prepare latents step that prepares the latents for the text-to-video generation process"
@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec(
"pachifier",
LTXVideoPachifier,
config=FrozenDict({"patch_size": 1, "patch_size_t": 1}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> list[InputParam]:
return [
InputParam.template("height", default=512),
InputParam.template("width", default=704),
InputParam("num_frames", type_hint=int, default=161),
InputParam.template("latents"),
InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"),
InputParam.template("generator"),
InputParam.template("batch_size", required=True),
]
@property
def intermediate_outputs(self) -> list[OutputParam]:
return [
OutputParam("latents", type_hint=torch.Tensor),
]
@torch.no_grad()
def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
batch_size = block_state.batch_size * block_state.num_videos_per_prompt
num_channels_latents = components.transformer.config.in_channels
if block_state.latents is not None:
block_state.latents = block_state.latents.to(device=device, dtype=torch.float32)
else:
height = block_state.height // components.vae_spatial_compression_ratio
width = block_state.width // components.vae_spatial_compression_ratio
num_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1
shape = (batch_size, num_channels_latents, num_frames, height, width)
block_state.latents = randn_tensor(
shape, generator=block_state.generator, device=device, dtype=torch.float32
)
block_state.latents = components.pachifier.pack_latents(block_state.latents)
self.set_block_state(state, block_state)
return components, state
class LTXImage2VideoPrepareLatentsStep(ModularPipelineBlocks):
model_name = "ltx"
@property
def description(self) -> str:
return (
"Prepare image-to-video latents: adds noise to pre-encoded image latents and creates a conditioning mask. "
"Expects pure noise `latents` from LTXPrepareLatentsStep."
)
@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec(
"pachifier",
LTXVideoPachifier,
config=FrozenDict({"patch_size": 1, "patch_size_t": 1}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> list[InputParam]:
return [
InputParam("image_latents", type_hint=torch.Tensor, required=True),
InputParam.template("latents", required=True),
InputParam.template("height", default=512),
InputParam.template("width", default=704),
InputParam("num_frames", type_hint=int, default=161),
InputParam.template("num_images_per_prompt", name="num_videos_per_prompt"),
InputParam.template("batch_size", required=True),
]
@property
def intermediate_outputs(self) -> list[OutputParam]:
return [
OutputParam("latents", type_hint=torch.Tensor),
OutputParam("conditioning_mask", type_hint=torch.Tensor),
]
@torch.no_grad()
def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
batch_size = block_state.batch_size * block_state.num_videos_per_prompt
height = block_state.height // components.vae_spatial_compression_ratio
width = block_state.width // components.vae_spatial_compression_ratio
num_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1
init_latents = block_state.image_latents.to(device=device, dtype=torch.float32)
if init_latents.shape[0] < batch_size:
init_latents = init_latents.repeat_interleave(batch_size // init_latents.shape[0], dim=0)
init_latents = init_latents.repeat(1, 1, num_frames, 1, 1)
conditioning_mask = torch.zeros(
init_latents.shape[0],
1,
init_latents.shape[2],
init_latents.shape[3],
init_latents.shape[4],
device=device,
dtype=torch.float32,
)
conditioning_mask[:, :, 0] = 1.0
noise = components.pachifier.unpack_latents(block_state.latents, num_frames, height, width)
latents = init_latents * conditioning_mask + noise * (1 - conditioning_mask)
conditioning_mask = components.pachifier.pack_latents(conditioning_mask).squeeze(-1)
latents = components.pachifier.pack_latents(latents)
block_state.latents = latents
block_state.conditioning_mask = conditioning_mask
self.set_block_state(state, block_state)
return components, state

View File

@@ -0,0 +1,132 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKLLTXVideo
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import LTXVideoPachifier
logger = logging.get_logger(__name__)
def _denormalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Denormalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents * latents_std / scaling_factor + latents_mean
return latents
class LTXVaeDecoderStep(ModularPipelineBlocks):
model_name = "ltx"
@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLLTXVideo),
ComponentSpec(
"video_processor",
VideoProcessor,
config=FrozenDict({"vae_scale_factor": 32}),
default_creation_method="from_config",
),
ComponentSpec(
"pachifier",
LTXVideoPachifier,
config=FrozenDict({"patch_size": 1, "patch_size_t": 1}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into videos"
@property
def inputs(self) -> list[tuple[str, Any]]:
return [
InputParam.template("latents", required=True),
InputParam.template("output_type", default="np"),
InputParam.template("height", default=512),
InputParam.template("width", default=704),
InputParam("num_frames", type_hint=int, default=161),
InputParam("decode_timestep", default=0.0),
InputParam("decode_noise_scale", default=None),
InputParam.template("generator"),
InputParam.template("batch_size"),
InputParam.template("dtype", required=True),
]
@property
def intermediate_outputs(self) -> list[OutputParam]:
return [OutputParam.template("videos")]
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
vae = components.vae
latents = block_state.latents
height = block_state.height
width = block_state.width
num_frames = block_state.num_frames
latent_num_frames = (num_frames - 1) // components.vae_temporal_compression_ratio + 1
latent_height = height // components.vae_spatial_compression_ratio
latent_width = width // components.vae_spatial_compression_ratio
latents = components.pachifier.unpack_latents(latents, latent_num_frames, latent_height, latent_width)
latents = _denormalize_latents(latents, vae.latents_mean, vae.latents_std, vae.config.scaling_factor)
latents = latents.to(block_state.dtype)
if not vae.config.timestep_conditioning:
timestep = None
else:
device = latents.device
batch_size = block_state.batch_size
decode_timestep = block_state.decode_timestep
decode_noise_scale = block_state.decode_noise_scale
noise = randn_tensor(latents.shape, generator=block_state.generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
:, None, None, None, None
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
latents = latents.to(vae.dtype)
video = vae.decode(latents, timestep, return_dict=False)[0]
block_state.videos = components.video_processor.postprocess_video(video, output_type=block_state.output_type)
self.set_block_state(state, block_state)
return components, state

View File

@@ -0,0 +1,458 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import LTXVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam
from .modular_pipeline import LTXModularPipeline, LTXVideoPachifier
class LTXLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "ltx"
@property
def description(self) -> str:
return (
"Step within the denoising loop that prepares the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `LTXDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> list[InputParam]:
return [
InputParam.template("latents", required=True),
InputParam.template("dtype", required=True),
]
@torch.no_grad()
def __call__(self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
block_state.latent_model_input = block_state.latents.to(block_state.dtype)
return components, block_state
class LTXLoopDenoiser(ModularPipelineBlocks):
model_name = "ltx"
def __init__(
self,
guider_input_fields: dict[str, Any] | None = None,
):
if guider_input_fields is None:
guider_input_fields = {
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
"encoder_attention_mask": ("prompt_attention_mask", "negative_prompt_attention_mask"),
}
if not isinstance(guider_input_fields, dict):
raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
self._guider_input_fields = guider_input_fields
super().__init__()
@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 3.0}),
default_creation_method="from_config",
),
ComponentSpec("transformer", LTXVideoTransformer3DModel),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoises the latents with guidance. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `LTXDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> list[tuple[str, Any]]:
inputs = [
InputParam.template("attention_kwargs"),
InputParam.template("num_inference_steps", required=True),
InputParam("rope_interpolation_scale", type_hint=tuple),
InputParam.template("height"),
InputParam.template("width"),
InputParam("num_frames", type_hint=int),
]
guider_input_names = []
for value in self._guider_input_fields.values():
if isinstance(value, tuple):
guider_input_names.extend(value)
else:
guider_input_names.append(value)
for name in guider_input_names:
inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor))
return inputs
@torch.no_grad()
def __call__(
self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
latent_num_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1
latent_height = block_state.height // components.vae_spatial_compression_ratio
latent_width = block_state.width // components.vae_spatial_compression_ratio
guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {
k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
for k, v in cond_kwargs.items()
if k in self._guider_input_fields.keys()
}
context_name = getattr(guider_state_batch, components.guider._identifier_key, None)
with components.transformer.cache_context(context_name):
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latent_model_input,
timestep=t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype),
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
rope_interpolation_scale=block_state.rope_interpolation_scale,
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
**cond_kwargs,
)[0]
components.guider.cleanup_models(components.transformer)
block_state.noise_pred = components.guider(guider_state)[0]
return components, block_state
class LTXLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "ltx"
@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
]
@property
def description(self) -> str:
return (
"Step within the denoising loop that updates the latents. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `LTXDenoiseLoopWrapper`)"
)
@torch.no_grad()
def __call__(self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
latents_dtype = block_state.latents.dtype
block_state.latents = components.scheduler.step(
block_state.noise_pred,
t,
block_state.latents,
return_dict=False,
)[0]
if block_state.latents.dtype != latents_dtype:
block_state.latents = block_state.latents.to(latents_dtype)
return components, block_state
class LTXDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
model_name = "ltx"
@property
def description(self) -> str:
return (
"Pipeline block that iteratively denoises the latents over `timesteps`. "
"The specific steps within each iteration can be customized with `sub_blocks` attributes"
)
@property
def loop_expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
ComponentSpec("transformer", LTXVideoTransformer3DModel),
]
@property
def loop_inputs(self) -> list[InputParam]:
return [
InputParam.template("timesteps", required=True),
InputParam.template("num_inference_steps", required=True),
]
@torch.no_grad()
def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.num_warmup_steps = max(
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
)
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)
if i == len(block_state.timesteps) - 1 or (
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
):
progress_bar.update()
self.set_block_state(state, block_state)
return components, state
class LTXDenoiseStep(LTXDenoiseLoopWrapper):
block_classes = [
LTXLoopBeforeDenoiser,
LTXLoopDenoiser(
guider_input_fields={
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
"encoder_attention_mask": ("prompt_attention_mask", "negative_prompt_attention_mask"),
}
),
LTXLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents.\n"
"Its loop logic is defined in `LTXDenoiseLoopWrapper.__call__` method.\n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `LTXLoopBeforeDenoiser`\n"
" - `LTXLoopDenoiser`\n"
" - `LTXLoopAfterDenoiser`\n"
"This block supports text-to-video tasks."
)
class LTXImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "ltx"
@property
def description(self) -> str:
return (
"Step within the i2v denoising loop that prepares the latent input and modulates "
"the timestep with the conditioning mask."
)
@property
def inputs(self) -> list[InputParam]:
return [
InputParam.template("latents", required=True),
InputParam("conditioning_mask", required=True, type_hint=torch.Tensor),
InputParam.template("dtype", required=True),
]
@torch.no_grad()
def __call__(self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
block_state.latent_model_input = block_state.latents.to(block_state.dtype)
block_state.timestep_adjusted = t.expand(block_state.latent_model_input.shape[0]).unsqueeze(-1) * (
1 - block_state.conditioning_mask
)
return components, block_state
class LTXImage2VideoLoopDenoiser(ModularPipelineBlocks):
model_name = "ltx"
def __init__(
self,
guider_input_fields: dict[str, Any] | None = None,
):
if guider_input_fields is None:
guider_input_fields = {
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
"encoder_attention_mask": ("prompt_attention_mask", "negative_prompt_attention_mask"),
}
if not isinstance(guider_input_fields, dict):
raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
self._guider_input_fields = guider_input_fields
super().__init__()
@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 3.0}),
default_creation_method="from_config",
),
ComponentSpec("transformer", LTXVideoTransformer3DModel),
]
@property
def description(self) -> str:
return (
"Step within the i2v denoising loop that denoises the latents with guidance "
"using timestep modulated by the conditioning mask."
)
@property
def inputs(self) -> list[tuple[str, Any]]:
inputs = [
InputParam.template("attention_kwargs"),
InputParam.template("num_inference_steps", required=True),
InputParam("rope_interpolation_scale", type_hint=tuple),
InputParam.template("height"),
InputParam.template("width"),
InputParam("num_frames", type_hint=int),
]
guider_input_names = []
for value in self._guider_input_fields.values():
if isinstance(value, tuple):
guider_input_names.extend(value)
else:
guider_input_names.append(value)
for name in guider_input_names:
inputs.append(InputParam(name=name, required=True, type_hint=torch.Tensor))
return inputs
@torch.no_grad()
def __call__(
self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
latent_num_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1
latent_height = block_state.height // components.vae_spatial_compression_ratio
latent_width = block_state.width // components.vae_spatial_compression_ratio
guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = guider_state_batch.as_dict()
cond_kwargs = {
k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
for k, v in cond_kwargs.items()
if k in self._guider_input_fields.keys()
}
context_name = getattr(guider_state_batch, components.guider._identifier_key, None)
with components.transformer.cache_context(context_name):
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latent_model_input,
timestep=block_state.timestep_adjusted,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
rope_interpolation_scale=block_state.rope_interpolation_scale,
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
**cond_kwargs,
)[0]
components.guider.cleanup_models(components.transformer)
block_state.noise_pred = components.guider(guider_state)[0]
return components, block_state
class LTXImage2VideoLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "ltx"
@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
ComponentSpec(
"pachifier",
LTXVideoPachifier,
config=FrozenDict({"patch_size": 1, "patch_size_t": 1}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return (
"Step within the i2v denoising loop that updates the latents, "
"applying the scheduler step only to frames after the first (conditioned) frame."
)
@property
def inputs(self) -> list[InputParam]:
return [
InputParam.template("height"),
InputParam.template("width"),
InputParam("num_frames", type_hint=int),
]
@torch.no_grad()
def __call__(self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
latent_num_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1
latent_height = block_state.height // components.vae_spatial_compression_ratio
latent_width = block_state.width // components.vae_spatial_compression_ratio
noise_pred = components.pachifier.unpack_latents(
block_state.noise_pred, latent_num_frames, latent_height, latent_width
)
latents = components.pachifier.unpack_latents(
block_state.latents, latent_num_frames, latent_height, latent_width
)
noise_pred = noise_pred[:, :, 1:]
noise_latents = latents[:, :, 1:]
pred_latents = components.scheduler.step(noise_pred, t, noise_latents, return_dict=False)[0]
latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)
block_state.latents = components.pachifier.pack_latents(latents)
return components, block_state
class LTXImage2VideoDenoiseStep(LTXDenoiseLoopWrapper):
block_classes = [
LTXImage2VideoLoopBeforeDenoiser,
LTXImage2VideoLoopDenoiser(
guider_input_fields={
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
"encoder_attention_mask": ("prompt_attention_mask", "negative_prompt_attention_mask"),
}
),
LTXImage2VideoLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step for image-to-video that iteratively denoises the latents.\n"
"The first frame is kept fixed via a conditioning mask.\n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `LTXImage2VideoLoopBeforeDenoiser`\n"
" - `LTXImage2VideoLoopDenoiser`\n"
" - `LTXImage2VideoLoopAfterDenoiser`"
)

View File

@@ -0,0 +1,273 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import AutoencoderKLLTXVideo
from ...utils import logging
from ...video_processor import VideoProcessor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import LTXModularPipeline
logger = logging.get_logger(__name__)
def _get_t5_prompt_embeds(
components,
prompt: str | list[str],
max_sequence_length: int,
device: torch.device,
dtype: torch.dtype,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
text_inputs = components.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.bool().to(device)
prompt_embeds = components.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
return prompt_embeds, prompt_attention_mask
class LTXTextEncoderStep(ModularPipelineBlocks):
model_name = "ltx"
@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings to guide the video generation"
@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec("text_encoder", T5EncoderModel),
ComponentSpec("tokenizer", T5TokenizerFast),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 3.0}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> list[InputParam]:
return [
InputParam.template("prompt"),
InputParam.template("negative_prompt"),
InputParam.template("max_sequence_length", default=128),
]
@property
def intermediate_outputs(self) -> list[OutputParam]:
return [
OutputParam.template("prompt_embeds"),
OutputParam.template("prompt_embeds_mask", name="prompt_attention_mask"),
OutputParam.template("negative_prompt_embeds"),
OutputParam.template("negative_prompt_embeds_mask", name="negative_prompt_attention_mask"),
]
@staticmethod
def check_inputs(block_state):
if block_state.prompt is not None and (
not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
@staticmethod
def encode_prompt(
components,
prompt: str,
device: torch.device | None = None,
prepare_unconditional_embeds: bool = True,
negative_prompt: str | None = None,
max_sequence_length: int = 128,
):
device = device or components._execution_device
dtype = components.text_encoder.dtype
if not isinstance(prompt, list):
prompt = [prompt]
batch_size = len(prompt)
prompt_embeds, prompt_attention_mask = _get_t5_prompt_embeds(
components=components,
prompt=prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
negative_prompt_embeds = None
negative_prompt_attention_mask = None
if prepare_unconditional_embeds:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_prompt_attention_mask = _get_t5_prompt_embeds(
components=components,
prompt=negative_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
@torch.no_grad()
def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.device = components._execution_device
(
block_state.prompt_embeds,
block_state.prompt_attention_mask,
block_state.negative_prompt_embeds,
block_state.negative_prompt_attention_mask,
) = self.encode_prompt(
components=components,
prompt=block_state.prompt,
device=block_state.device,
prepare_unconditional_embeds=components.requires_unconditional_embeds,
negative_prompt=block_state.negative_prompt,
max_sequence_length=block_state.max_sequence_length,
)
self.set_block_state(state, block_state)
return components, state
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Normalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents
class LTXVaeEncoderStep(ModularPipelineBlocks):
model_name = "ltx"
@property
def description(self) -> str:
return "VAE Encoder step that encodes an input image into latent space for image-to-video generation"
@property
def expected_components(self) -> list[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLLTXVideo),
ComponentSpec(
"video_processor",
VideoProcessor,
config=FrozenDict({"vae_scale_factor": 32}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> list[InputParam]:
return [
InputParam.template("image", required=True),
InputParam.template("height", default=512),
InputParam.template("width", default=704),
InputParam.template("generator"),
]
@property
def intermediate_outputs(self) -> list[OutputParam]:
return [
OutputParam(
"image_latents",
type_hint=torch.Tensor,
description="Encoded image latents from the VAE encoder",
),
]
@torch.no_grad()
def __call__(self, components: LTXModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
image = block_state.image
if not isinstance(image, torch.Tensor):
image = components.video_processor.preprocess(image, height=block_state.height, width=block_state.width)
image = image.to(device=device, dtype=torch.float32)
vae_dtype = components.vae.dtype
num_images = image.shape[0]
if isinstance(block_state.generator, list):
init_latents = [
retrieve_latents(
components.vae.encode(image[i].unsqueeze(0).unsqueeze(2).to(vae_dtype)),
block_state.generator[i],
)
for i in range(num_images)
]
else:
init_latents = [
retrieve_latents(
components.vae.encode(img.unsqueeze(0).unsqueeze(2).to(vae_dtype)),
block_state.generator,
)
for img in image
]
init_latents = torch.cat(init_latents, dim=0).to(torch.float32)
block_state.image_latents = _normalize_latents(
init_latents, components.vae.latents_mean, components.vae.latents_std
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -0,0 +1,487 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import OutputParam
from .before_denoise import (
LTXImage2VideoPrepareLatentsStep,
LTXPrepareLatentsStep,
LTXSetTimestepsStep,
LTXTextInputStep,
)
from .decoders import LTXVaeDecoderStep
from .denoise import LTXDenoiseStep, LTXImage2VideoDenoiseStep
from .encoders import LTXTextEncoderStep, LTXVaeEncoderStep
logger = logging.get_logger(__name__)
# auto_docstring
class LTXCoreDenoiseStep(SequentialPipelineBlocks):
"""
Denoise block that takes encoded conditions and runs the denoising process.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`LTXVideoPachifier`) guider
(`ClassifierFreeGuidance`) transformer (`LTXVideoTransformer3DModel`)
Inputs:
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_attention_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_attention_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
timesteps (`Tensor`, *optional*):
Timesteps for the denoising process.
sigmas (`list`, *optional*):
Custom sigmas for the denoising process.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 704):
The width in pixels of the generated image.
num_frames (`int`, *optional*, defaults to 161):
TODO: Add description.
frame_rate (`int`, *optional*, defaults to 25):
TODO: Add description.
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
attention_kwargs (`dict`, *optional*):
Additional kwargs for attention processors.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "ltx"
block_classes = [
LTXTextInputStep,
LTXSetTimestepsStep,
LTXPrepareLatentsStep,
LTXDenoiseStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
@property
def description(self):
return "Denoise block that takes encoded conditions and runs the denoising process."
@property
def outputs(self):
return [OutputParam.template("latents")]
# auto_docstring
class LTXImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
"""
Denoise block for image-to-video that takes encoded conditions and image latents, and runs the denoising process.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`LTXVideoPachifier`) guider
(`ClassifierFreeGuidance`) transformer (`LTXVideoTransformer3DModel`)
Inputs:
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_attention_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_attention_mask (`Tensor`, *optional*):
mask for the negative text embeddings. Can be generated from text_encoder step.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
timesteps (`Tensor`, *optional*):
Timesteps for the denoising process.
sigmas (`list`, *optional*):
Custom sigmas for the denoising process.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 704):
The width in pixels of the generated image.
num_frames (`int`, *optional*, defaults to 161):
TODO: Add description.
frame_rate (`int`, *optional*, defaults to 25):
TODO: Add description.
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
image_latents (`Tensor`):
TODO: Add description.
attention_kwargs (`dict`, *optional*):
Additional kwargs for attention processors.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "ltx"
block_classes = [
LTXTextInputStep,
LTXSetTimestepsStep,
LTXPrepareLatentsStep,
LTXImage2VideoPrepareLatentsStep,
LTXImage2VideoDenoiseStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "prepare_i2v_latents", "denoise"]
@property
def description(self):
return "Denoise block for image-to-video that takes encoded conditions and image latents, and runs the denoising process."
@property
def outputs(self):
return [OutputParam.template("latents")]
# auto_docstring
class LTXBlocks(SequentialPipelineBlocks):
"""
Modular pipeline blocks for LTX Video text-to-video.
Components:
text_encoder (`T5EncoderModel`) tokenizer (`T5TokenizerFast`) guider (`ClassifierFreeGuidance`) scheduler
(`FlowMatchEulerDiscreteScheduler`) pachifier (`LTXVideoPachifier`) transformer
(`LTXVideoTransformer3DModel`) vae (`AutoencoderKLLTXVideo`) video_processor (`VideoProcessor`)
Inputs:
prompt (`str`):
The prompt or prompts to guide image generation.
negative_prompt (`str`, *optional*):
The prompt or prompts not to guide the image generation.
max_sequence_length (`int`, *optional*, defaults to 128):
Maximum sequence length for prompt encoding.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
timesteps (`Tensor`, *optional*):
Timesteps for the denoising process.
sigmas (`list`, *optional*):
Custom sigmas for the denoising process.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 704):
The width in pixels of the generated image.
num_frames (`int`, *optional*, defaults to 161):
TODO: Add description.
frame_rate (`int`, *optional*, defaults to 25):
TODO: Add description.
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
attention_kwargs (`dict`, *optional*):
Additional kwargs for attention processors.
output_type (`str`, *optional*, defaults to np):
Output format: 'pil', 'np', 'pt'.
decode_timestep (`None`, *optional*, defaults to 0.0):
TODO: Add description.
decode_noise_scale (`None`, *optional*):
TODO: Add description.
Outputs:
videos (`list`):
The generated videos.
"""
model_name = "ltx"
block_classes = [
LTXTextEncoderStep,
LTXCoreDenoiseStep,
LTXVaeDecoderStep,
]
block_names = ["text_encoder", "denoise", "decode"]
@property
def description(self):
return "Modular pipeline blocks for LTX Video text-to-video."
@property
def outputs(self):
return [OutputParam.template("videos")]
# auto_docstring
class LTXAutoVaeEncoderStep(AutoPipelineBlocks):
"""
VAE encoder step that encodes the image input into its latent representation.
This is an auto pipeline block that works for image-to-video tasks.
- `LTXVaeEncoderStep` is used when `image` is provided.
- If `image` is not provided, step will be skipped.
Components:
vae (`AutoencoderKLLTXVideo`) video_processor (`VideoProcessor`)
Inputs:
image (`Image | list`, *optional*):
Reference image(s) for denoising. Can be a single image or list of images.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 704):
The width in pixels of the generated image.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
Outputs:
image_latents (`Tensor`):
Encoded image latents from the VAE encoder
"""
model_name = "ltx"
block_classes = [LTXVaeEncoderStep]
block_names = ["vae_encoder"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"VAE encoder step that encodes the image input into its latent representation.\n"
"This is an auto pipeline block that works for image-to-video tasks.\n"
" - `LTXVaeEncoderStep` is used when `image` is provided.\n"
" - If `image` is not provided, step will be skipped."
)
# auto_docstring
class LTXAutoCoreDenoiseStep(AutoPipelineBlocks):
"""
Auto denoise block that selects the appropriate denoise pipeline based on inputs.
- `LTXImage2VideoCoreDenoiseStep` is used when `image_latents` is provided.
- `LTXCoreDenoiseStep` is used otherwise (text-to-video).
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`) pachifier (`LTXVideoPachifier`) guider
(`ClassifierFreeGuidance`) transformer (`LTXVideoTransformer3DModel`)
Inputs:
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation. Can be generated from text_encoder step.
prompt_attention_mask (`Tensor`):
mask for the text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`):
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
negative_prompt_attention_mask (`Tensor`):
mask for the negative text embeddings. Can be generated from text_encoder step.
num_inference_steps (`int`):
The number of denoising steps.
timesteps (`Tensor`):
Timesteps for the denoising process.
sigmas (`list`, *optional*):
Custom sigmas for the denoising process.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 704):
The width in pixels of the generated image.
num_frames (`int`, *optional*, defaults to 161):
TODO: Add description.
frame_rate (`int`, *optional*, defaults to 25):
TODO: Add description.
latents (`Tensor`):
Pre-generated noisy latents for image generation.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
image_latents (`Tensor`, *optional*):
TODO: Add description.
attention_kwargs (`dict`, *optional*):
Additional kwargs for attention processors.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "ltx"
block_classes = [LTXImage2VideoCoreDenoiseStep, LTXCoreDenoiseStep]
block_names = ["image2video", "text2video"]
block_trigger_inputs = ["image_latents", None]
@property
def description(self):
return (
"Auto denoise block that selects the appropriate denoise pipeline based on inputs.\n"
" - `LTXImage2VideoCoreDenoiseStep` is used when `image_latents` is provided.\n"
" - `LTXCoreDenoiseStep` is used otherwise (text-to-video)."
)
# auto_docstring
class LTXAutoBlocks(SequentialPipelineBlocks):
"""
Auto blocks for LTX Video that support both text-to-video and image-to-video workflows.
Supported workflows:
- `text2video`: requires `prompt`
- `image2video`: requires `image`, `prompt`
Components:
text_encoder (`T5EncoderModel`) tokenizer (`T5TokenizerFast`) guider (`ClassifierFreeGuidance`) vae
(`AutoencoderKLLTXVideo`) video_processor (`VideoProcessor`) scheduler (`FlowMatchEulerDiscreteScheduler`)
pachifier (`LTXVideoPachifier`) transformer (`LTXVideoTransformer3DModel`)
Inputs:
prompt (`str`):
The prompt or prompts to guide image generation.
negative_prompt (`str`, *optional*):
The prompt or prompts not to guide the image generation.
max_sequence_length (`int`, *optional*, defaults to 128):
Maximum sequence length for prompt encoding.
image (`Image | list`, *optional*):
Reference image(s) for denoising. Can be a single image or list of images.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 704):
The width in pixels of the generated image.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
num_inference_steps (`int`):
The number of denoising steps.
timesteps (`Tensor`):
Timesteps for the denoising process.
sigmas (`list`, *optional*):
Custom sigmas for the denoising process.
num_frames (`int`, *optional*, defaults to 161):
TODO: Add description.
frame_rate (`int`, *optional*, defaults to 25):
TODO: Add description.
latents (`Tensor`):
Pre-generated noisy latents for image generation.
image_latents (`Tensor`, *optional*):
TODO: Add description.
attention_kwargs (`dict`, *optional*):
Additional kwargs for attention processors.
output_type (`str`, *optional*, defaults to np):
Output format: 'pil', 'np', 'pt'.
decode_timestep (`None`, *optional*, defaults to 0.0):
TODO: Add description.
decode_noise_scale (`None`, *optional*):
TODO: Add description.
Outputs:
videos (`list`):
The generated videos.
"""
model_name = "ltx"
block_classes = [
LTXTextEncoderStep,
LTXAutoVaeEncoderStep,
LTXAutoCoreDenoiseStep,
LTXVaeDecoderStep,
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
_workflow_map = {
"text2video": {"prompt": True},
"image2video": {"image": True, "prompt": True},
}
@property
def description(self):
return "Auto blocks for LTX Video that support both text-to-video and image-to-video workflows."
@property
def outputs(self):
return [OutputParam.template("videos")]
# auto_docstring
class LTXImage2VideoBlocks(SequentialPipelineBlocks):
"""
Modular pipeline blocks for LTX Video image-to-video.
Components:
text_encoder (`T5EncoderModel`) tokenizer (`T5TokenizerFast`) guider (`ClassifierFreeGuidance`) vae
(`AutoencoderKLLTXVideo`) video_processor (`VideoProcessor`) scheduler (`FlowMatchEulerDiscreteScheduler`)
pachifier (`LTXVideoPachifier`) transformer (`LTXVideoTransformer3DModel`)
Inputs:
prompt (`str`):
The prompt or prompts to guide image generation.
negative_prompt (`str`, *optional*):
The prompt or prompts not to guide the image generation.
max_sequence_length (`int`, *optional*, defaults to 128):
Maximum sequence length for prompt encoding.
image (`Image | list`, *optional*):
Reference image(s) for denoising. Can be a single image or list of images.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 704):
The width in pixels of the generated image.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
timesteps (`Tensor`, *optional*):
Timesteps for the denoising process.
sigmas (`list`, *optional*):
Custom sigmas for the denoising process.
num_frames (`int`, *optional*, defaults to 161):
TODO: Add description.
frame_rate (`int`, *optional*, defaults to 25):
TODO: Add description.
latents (`Tensor`, *optional*):
Pre-generated noisy latents for image generation.
image_latents (`Tensor`):
TODO: Add description.
attention_kwargs (`dict`, *optional*):
Additional kwargs for attention processors.
output_type (`str`, *optional*, defaults to np):
Output format: 'pil', 'np', 'pt'.
decode_timestep (`None`, *optional*, defaults to 0.0):
TODO: Add description.
decode_noise_scale (`None`, *optional*):
TODO: Add description.
Outputs:
videos (`list`):
The generated videos.
"""
model_name = "ltx"
block_classes = [
LTXTextEncoderStep,
LTXAutoVaeEncoderStep,
LTXImage2VideoCoreDenoiseStep,
LTXVaeDecoderStep,
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
@property
def description(self):
return "Modular pipeline blocks for LTX Video image-to-video."
@property
def outputs(self):
return [OutputParam.template("videos")]

View File

@@ -0,0 +1,95 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import LTXVideoLoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
logger = logging.get_logger(__name__)
class LTXVideoPachifier(ConfigMixin):
"""
A class to pack and unpack latents for LTX Video.
"""
config_name = "config.json"
@register_to_config
def __init__(self, patch_size: int = 1, patch_size_t: int = 1):
super().__init__()
def pack_latents(self, latents: torch.Tensor) -> torch.Tensor:
batch_size, _, num_frames, height, width = latents.shape
patch_size = self.config.patch_size
patch_size_t = self.config.patch_size_t
post_patch_num_frames = num_frames // patch_size_t
post_patch_height = height // patch_size
post_patch_width = width // patch_size
latents = latents.reshape(
batch_size,
-1,
post_patch_num_frames,
patch_size_t,
post_patch_height,
patch_size,
post_patch_width,
patch_size,
)
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
return latents
def unpack_latents(self, latents: torch.Tensor, num_frames: int, height: int, width: int) -> torch.Tensor:
batch_size = latents.size(0)
patch_size = self.config.patch_size
patch_size_t = self.config.patch_size_t
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
return latents
class LTXModularPipeline(
ModularPipeline,
LTXVideoLoraLoaderMixin,
):
"""
A ModularPipeline for LTX Video.
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "LTXAutoBlocks"
@property
def vae_spatial_compression_ratio(self):
if getattr(self, "vae", None) is not None:
return self.vae.spatial_compression_ratio
return 32
@property
def vae_temporal_compression_ratio(self):
if getattr(self, "vae", None) is not None:
return self.vae.temporal_compression_ratio
return 8
@property
def requires_unconditional_embeds(self):
if hasattr(self, "guider") and self.guider is not None:
return self.guider._enabled and self.guider.num_conditions > 1
return False

View File

@@ -132,6 +132,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
("z-image", _create_default_map_fn("ZImageModularPipeline")),
("helios", _create_default_map_fn("HeliosModularPipeline")),
("helios-pyramid", _helios_pyramid_map_fn),
("ltx", _create_default_map_fn("LTXModularPipeline")),
]
)

View File

@@ -335,6 +335,7 @@ else:
)
_import_structure["mochi"] = ["MochiPipeline"]
_import_structure["omnigen"] = ["OmniGenPipeline"]
_import_structure["ernie_image"] = ["ErnieImagePipeline"]
_import_structure["ovis_image"] = ["OvisImagePipeline"]
_import_structure["visualcloze"] = ["VisualClozePipeline", "VisualClozeGenerationPipeline"]
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
@@ -678,6 +679,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EasyAnimateInpaintPipeline,
EasyAnimatePipeline,
)
from .ernie_image import ErnieImagePipeline
from .flux import (
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,

View File

@@ -5,10 +5,13 @@ import cv2
import numpy as np
import torch
from PIL import Image, ImageOps
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import normalize, resize
from ...utils import get_logger, load_image
from ...utils import get_logger, is_torchvision_available, load_image
if is_torchvision_available():
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import normalize, resize
logger = get_logger(__name__)

View File

@@ -0,0 +1,47 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa: F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_ernie_image"] = ["ErnieImagePipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_ernie_image import ErnieImagePipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,389 @@
# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Ernie-Image Pipeline for HuggingFace Diffusers.
"""
import json
from typing import Callable, List, Optional, Union
import torch
from PIL import Image
from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer
from ...models import AutoencoderKLFlux2
from ...models.transformers import ErnieImageTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils.torch_utils import randn_tensor
from .pipeline_output import ErnieImagePipelineOutput
class ErnieImagePipeline(DiffusionPipeline):
"""
Pipeline for text-to-image generation using ErnieImageTransformer2DModel.
This pipeline uses:
- A custom DiT transformer model
- A Flux2-style VAE for encoding/decoding latents
- A text encoder (e.g., Qwen) for text conditioning
- Flow Matching Euler Discrete Scheduler
"""
model_cpu_offload_seq = "pe->text_encoder->transformer->vae"
# For SGLang fallback ...
_optional_components = ["pe", "pe_tokenizer"]
_callback_tensor_inputs = ["latents"]
def __init__(
self,
transformer: ErnieImageTransformer2DModel,
vae: AutoencoderKLFlux2,
text_encoder: AutoModel,
tokenizer: AutoTokenizer,
scheduler: FlowMatchEulerDiscreteScheduler,
pe: Optional[AutoModelForCausalLM] = None,
pe_tokenizer: Optional[AutoTokenizer] = None,
):
super().__init__()
self.register_modules(
transformer=transformer,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
pe=pe,
pe_tokenizer=pe_tokenizer,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels)) if getattr(self, "vae", None) else 16
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@torch.no_grad()
def _enhance_prompt_with_pe(
self,
prompt: str,
device: torch.device,
width: int = 1024,
height: int = 1024,
system_prompt: Optional[str] = None,
temperature: float = 0.6,
top_p: float = 0.95,
) -> str:
"""Use PE model to rewrite/enhance a short prompt via chat_template."""
# Build user message as JSON carrying prompt text and target resolution
user_content = json.dumps(
{"prompt": prompt, "width": width, "height": height},
ensure_ascii=False,
)
messages = []
if system_prompt is not None:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": user_content})
# apply_chat_template picks up the chat_template.jinja loaded with pe_tokenizer
input_text = self.pe_tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False, # "Output:" is already in the user block
)
inputs = self.pe_tokenizer(input_text, return_tensors="pt").to(device)
output_ids = self.pe.generate(
**inputs,
max_new_tokens=self.pe_tokenizer.model_max_length,
do_sample=temperature != 1.0 or top_p != 1.0,
temperature=temperature,
top_p=top_p,
pad_token_id=self.pe_tokenizer.pad_token_id,
eos_token_id=self.pe_tokenizer.eos_token_id,
)
# Decode only newly generated tokens
generated_ids = output_ids[0][inputs["input_ids"].shape[1] :]
return self.pe_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
@torch.no_grad()
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: torch.device,
num_images_per_prompt: int = 1,
) -> List[torch.Tensor]:
"""Encode text prompts to embeddings."""
if isinstance(prompt, str):
prompt = [prompt]
text_hiddens = []
for p in prompt:
ids = self.tokenizer(
p,
add_special_tokens=True,
truncation=True,
padding=False,
)["input_ids"]
if len(ids) == 0:
if self.tokenizer.bos_token_id is not None:
ids = [self.tokenizer.bos_token_id]
else:
ids = [0]
input_ids = torch.tensor([ids], device=device)
with torch.no_grad():
outputs = self.text_encoder(
input_ids=input_ids,
output_hidden_states=True,
)
# Use second to last hidden state (matches training)
hidden = outputs.hidden_states[-2][0] # [T, H]
# Repeat for num_images_per_prompt
for _ in range(num_images_per_prompt):
text_hiddens.append(hidden)
return text_hiddens
@staticmethod
def _patchify_latents(latents: torch.Tensor) -> torch.Tensor:
"""2x2 patchify: [B, 32, H, W] -> [B, 128, H/2, W/2]"""
b, c, h, w = latents.shape
latents = latents.view(b, c, h // 2, 2, w // 2, 2)
latents = latents.permute(0, 1, 3, 5, 2, 4)
return latents.reshape(b, c * 4, h // 2, w // 2)
@staticmethod
def _unpatchify_latents(latents: torch.Tensor) -> torch.Tensor:
"""Reverse patchify: [B, 128, H/2, W/2] -> [B, 32, H, W]"""
b, c, h, w = latents.shape
latents = latents.reshape(b, c // 4, 2, 2, h, w)
latents = latents.permute(0, 1, 4, 2, 5, 3)
return latents.reshape(b, c // 4, h * 2, w * 2)
@staticmethod
def _pad_text(text_hiddens: List[torch.Tensor], device: torch.device, dtype: torch.dtype, text_in_dim: int):
B = len(text_hiddens)
if B == 0:
return torch.zeros((0, 0, text_in_dim), device=device, dtype=dtype), torch.zeros(
(0,), device=device, dtype=torch.long
)
normalized = [
th.squeeze(1).to(device).to(dtype) if th.dim() == 3 else th.to(device).to(dtype) for th in text_hiddens
]
lens = torch.tensor([t.shape[0] for t in normalized], device=device, dtype=torch.long)
Tmax = int(lens.max().item())
text_bth = torch.zeros((B, Tmax, text_in_dim), device=device, dtype=dtype)
for i, t in enumerate(normalized):
text_bth[i, : t.shape[0], :] = t
return text_bth, lens
@torch.no_grad()
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
negative_prompt: Optional[Union[str, List[str]]] = "",
height: int = 1024,
width: int = 1024,
num_inference_steps: int = 50,
guidance_scale: float = 4.0,
num_images_per_prompt: int = 1,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: list[torch.FloatTensor] | None = None,
negative_prompt_embeds: list[torch.FloatTensor] | None = None,
output_type: str = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[Callable[[int, int, dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
use_pe: bool = True, # 默认使用PE进行改写
):
"""
Generate images from text prompts.
Args:
prompt: Text prompt(s)
negative_prompt: Negative prompt(s) for CFG. Default is "".
height: Image height in pixels (must be divisible by 16). Default: 1024.
width: Image width in pixels (must be divisible by 16). Default: 1024.
num_inference_steps: Number of denoising steps
guidance_scale: CFG scale (1.0 = no guidance). Default: 4.0.
num_images_per_prompt: Number of images per prompt
generator: Random generator for reproducibility
latents: Pre-generated latents (optional)
prompt_embeds: Pre-computed text embeddings for positive prompts (optional).
If provided, `encode_prompt` is skipped for positive prompts.
negative_prompt_embeds: Pre-computed text embeddings for negative prompts (optional).
If provided, `encode_prompt` is skipped for negative prompts.
output_type: "pil" or "latent"
return_dict: Whether to return a dataclass
callback_on_step_end: Optional callback invoked at the end of each denoising step.
Called as `callback_on_step_end(pipeline, step, timestep, callback_kwargs)` where `callback_kwargs`
contains the tensors listed in `callback_on_step_end_tensor_inputs`. The callback may return a dict to
override those tensors for subsequent steps.
callback_on_step_end_tensor_inputs: List of tensor names passed into the callback kwargs.
Must be a subset of `_callback_tensor_inputs` (default: `["latents"]`).
use_pe: Whether to use the PE model to enhance prompts before generation.
Returns:
:class:`ErnieImagePipelineOutput` with `images` and `revised_prompts`.
"""
device = self._execution_device
dtype = self.transformer.dtype
self._guidance_scale = guidance_scale
# Validate prompt / prompt_embeds
if prompt is None and prompt_embeds is None:
raise ValueError("Must provide either `prompt` or `prompt_embeds`.")
if prompt is not None and prompt_embeds is not None:
raise ValueError("Cannot provide both `prompt` and `prompt_embeds` at the same time.")
# Validate dimensions
if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0:
raise ValueError(f"Height and width must be divisible by {self.vae_scale_factor}")
# Handle prompts
if prompt is not None:
if isinstance(prompt, str):
prompt = [prompt]
# [Phase 1] PE: enhance prompts
revised_prompts: Optional[List[str]] = None
if prompt is not None and use_pe and self.pe is not None and self.pe_tokenizer is not None:
prompt = [self._enhance_prompt_with_pe(p, device, width=width, height=height) for p in prompt]
revised_prompts = list(prompt)
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = len(prompt_embeds)
total_batch_size = batch_size * num_images_per_prompt
# Handle negative prompt
if negative_prompt is None:
negative_prompt = ""
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
if len(negative_prompt) != batch_size:
raise ValueError(f"negative_prompt must have same length as prompt ({batch_size})")
# [Phase 2] Text encoding
if prompt_embeds is not None:
text_hiddens = prompt_embeds
else:
text_hiddens = self.encode_prompt(prompt, device, num_images_per_prompt)
# CFG with negative prompt
if self.do_classifier_free_guidance:
if negative_prompt_embeds is not None:
uncond_text_hiddens = negative_prompt_embeds
else:
uncond_text_hiddens = self.encode_prompt(negative_prompt, device, num_images_per_prompt)
# Latent dimensions
latent_h = height // self.vae_scale_factor
latent_w = width // self.vae_scale_factor
latent_channels = self.transformer.config.in_channels # After patchify
# Initialize latents
if latents is None:
latents = randn_tensor(
(total_batch_size, latent_channels, latent_h, latent_w),
generator=generator,
device=device,
dtype=dtype,
)
# Setup scheduler
sigmas = torch.linspace(1.0, 0.0, num_inference_steps + 1)
self.scheduler.set_timesteps(sigmas=sigmas[:-1], device=device)
# Denoising loop
if self.do_classifier_free_guidance:
cfg_text_hiddens = list(uncond_text_hiddens) + list(text_hiddens)
else:
cfg_text_hiddens = text_hiddens
text_bth, text_lens = self._pad_text(
text_hiddens=cfg_text_hiddens, device=device, dtype=dtype, text_in_dim=self.transformer.config.text_in_dim
)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(self.scheduler.timesteps):
if self.do_classifier_free_guidance:
latent_model_input = torch.cat([latents, latents], dim=0)
t_batch = torch.full((total_batch_size * 2,), t.item(), device=device, dtype=dtype)
else:
latent_model_input = latents
t_batch = torch.full((total_batch_size,), t.item(), device=device, dtype=dtype)
# Model prediction
pred = self.transformer(
hidden_states=latent_model_input,
timestep=t_batch,
text_bth=text_bth,
text_lens=text_lens,
return_dict=False,
)[0]
# Apply CFG
if self.do_classifier_free_guidance:
pred_uncond, pred_cond = pred.chunk(2, dim=0)
pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
# Scheduler step
latents = self.scheduler.step(pred, t, latents).prev_sample
# Callback
if callback_on_step_end is not None:
callback_kwargs = {k: locals()[k] for k in callback_on_step_end_tensor_inputs}
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
progress_bar.update()
if output_type == "latent":
return latents
# Decode latents to images
# Unnormalize latents using VAE's BN stats
bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(device)
bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + 1e-5).to(device)
latents = latents * bn_std + bn_mean
# Unpatchify
latents = self._unpatchify_latents(latents)
# Decode
images = self.vae.decode(latents, return_dict=False)[0]
# Post-process
images = (images.clamp(-1, 1) + 1) / 2
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
images = [Image.fromarray((img * 255).astype("uint8")) for img in images]
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (images,)
return ErnieImagePipelineOutput(images=images, revised_prompts=revised_prompts)

View File

@@ -0,0 +1,36 @@
# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional
import PIL.Image
from ...utils import BaseOutput
@dataclass
class ErnieImagePipelineOutput(BaseOutput):
"""
Output class for ERNIE-Image pipelines.
Args:
images (`List[PIL.Image.Image]`):
List of generated images.
revised_prompts (`List[str]`, *optional*):
List of PE-revised prompts. `None` when PE is disabled or unavailable.
"""
images: List[PIL.Image.Image]
revised_prompts: Optional[List[str]]

View File

@@ -611,7 +611,7 @@ class HunyuanVideo15ImageToVideoPipeline(DiffusionPipeline):
tuple: (cond_latents_concat, mask_concat) - both are zero tensors for t2v
"""
batch, channels, frames, height, width = latents.shape
batch, channels, frames, latent_height, latent_width = latents.shape
image_latents = self._get_image_latents(
vae=self.vae,
@@ -626,7 +626,7 @@ class HunyuanVideo15ImageToVideoPipeline(DiffusionPipeline):
latent_condition[:, :, 1:, :, :] = 0
latent_condition = latent_condition.to(device=device, dtype=dtype)
latent_mask = torch.zeros(batch, 1, frames, height, width, dtype=dtype, device=device)
latent_mask = torch.zeros(batch, 1, frames, latent_height, latent_width, dtype=dtype, device=device)
latent_mask[:, :, 0, :, :] = 1.0
return latent_condition, latent_mask

View File

@@ -1110,6 +1110,21 @@ class EasyAnimateTransformer3DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class ErnieImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class Flux2Transformer2DModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -242,6 +242,36 @@ class HeliosPyramidModularPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LTXAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LTXModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class QwenImageAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1202,6 +1232,21 @@ class EasyAnimatePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class ErnieImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinKVPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -465,7 +465,8 @@ class UNetTesterMixin:
def test_forward_with_norm_groups(self):
if not self._accepts_norm_num_groups(self.model_class):
pytest.skip(f"Test not supported for {self.model_class.__name__}")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)
@@ -480,9 +481,9 @@ class UNetTesterMixin:
if isinstance(output, dict):
output = output.to_tuple()[0]
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
class ModelTesterMixin:

View File

@@ -287,8 +287,9 @@ class ModelTesterMixin:
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
)
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
inputs_dict = self.get_dummy_inputs()
image = model(**inputs_dict, return_dict=False)[0]
new_image = new_model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@@ -308,8 +309,9 @@ class ModelTesterMixin:
new_model.to(torch_device)
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
inputs_dict = self.get_dummy_inputs()
image = model(**inputs_dict, return_dict=False)[0]
new_image = new_model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@@ -337,8 +339,9 @@ class ModelTesterMixin:
model.to(torch_device)
model.eval()
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
inputs_dict = self.get_dummy_inputs()
first = model(**inputs_dict, return_dict=False)[0]
second = model(**inputs_dict, return_dict=False)[0]
first_flat = first.flatten()
second_flat = second.flatten()
@@ -395,8 +398,9 @@ class ModelTesterMixin:
model.to(torch_device)
model.eval()
outputs_dict = model(**self.get_dummy_inputs())
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
inputs_dict = self.get_dummy_inputs()
outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
@@ -523,8 +527,10 @@ class ModelTesterMixin:
new_model = new_model.to(torch_device)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
# Re-create inputs only if they contain a generator (which needs to be reset)
if "generator" in inputs_dict:
inputs_dict = self.get_dummy_inputs()
new_output = new_model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
@@ -563,8 +569,10 @@ class ModelTesterMixin:
new_model = new_model.to(torch_device)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
# Re-create inputs only if they contain a generator (which needs to be reset)
if "generator" in inputs_dict:
inputs_dict = self.get_dummy_inputs()
new_output = new_model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"
@@ -614,8 +622,10 @@ class ModelTesterMixin:
model_parallel = model_parallel.to(torch_device)
torch.manual_seed(0)
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
# Re-create inputs only if they contain a generator (which needs to be reset)
if "generator" in inputs_dict:
inputs_dict = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"

View File

@@ -92,9 +92,6 @@ class TorchCompileTesterMixin:
model.eval()
model.compile_repeated_blocks(fullgraph=True)
if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=recompile_limit),

View File

@@ -0,0 +1,132 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
from diffusers import ErnieImageTransformer2DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import torch_device
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
# Ernie-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations.
# Cannot use enable_full_determinism() which sets it to True.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if hasattr(torch.backends, "cuda"):
torch.backends.cuda.matmul.allow_tf32 = False
class ErnieImageTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return ErnieImageTransformer2DModel
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def output_shape(self) -> tuple:
return (16, 16, 16)
@property
def input_shape(self) -> tuple:
return (16, 16, 16)
@property
def model_split_percents(self) -> list:
# We override the items here because the transformer under consideration is small.
return [0.9, 0.9, 0.9]
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_size": 16,
"num_attention_heads": 1,
"num_layers": 1,
"ffn_hidden_size": 16,
"in_channels": 16,
"out_channels": 16,
"patch_size": 1,
"text_in_dim": 16,
"rope_theta": 256,
"rope_axes_dim": (8, 4, 4),
"eps": 1e-6,
"qk_layernorm": True,
}
def get_dummy_inputs(self, height: int = 16, width: int = 16, batch_size: int = 1) -> dict:
num_channels = 16 # in_channels
sequence_length = 16
text_in_dim = 16 # text_in_dim
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.tensor([1.0] * batch_size, device=torch_device),
"text_bth": randn_tensor(
(batch_size, sequence_length, text_in_dim), generator=self.generator, device=torch_device
),
"text_lens": torch.tensor([sequence_length] * batch_size, device=torch_device),
}
class TestErnieImageTransformer(ErnieImageTransformerTesterConfig, ModelTesterMixin):
pass
class TestErnieImageTransformerTraining(ErnieImageTransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"ErnieImageTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestErnieImageTransformerCompile(ErnieImageTransformerTesterConfig, TorchCompileTesterMixin):
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
@pytest.mark.skip(
reason="The repeated block in this model is ErnieImageSharedAdaLNBlock. As a consequence of this, "
"the inputs recorded for the block would vary during compilation and full compilation with "
"fullgraph=True would trigger recompilation."
)
def test_torch_compile_recompilation_and_graph_break(self):
super().test_torch_compile_recompilation_and_graph_break()
@pytest.mark.skip(reason="Fullgraph AoT is broken.")
def test_compile_works_with_aot(self, tmp_path):
super().test_compile_works_with_aot(tmp_path)
@pytest.mark.skip(reason="Fullgraph is broken.")
def test_compile_on_different_shapes(self):
super().test_compile_on_different_shapes()

View File

@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
import torch
@@ -26,64 +24,39 @@ from ...testing_utils import (
slow,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
)
class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet1DModel
main_input_name = "sample"
_LAYERWISE_CASTING_XFAIL_REASON = (
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
)
class UNet1DTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet1DModel testing (standard variant)."""
@property
def dummy_input(self):
batch_size = 4
num_features = 14
seq_len = 16
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
time_step = torch.tensor([10] * batch_size).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (4, 14, 16)
def model_class(self):
return UNet1DModel
@property
def output_shape(self):
return (4, 14, 16)
return (14, 16)
@unittest.skip("Test not supported.")
def test_ema_training(self):
pass
@property
def main_input_name(self):
return "sample"
@unittest.skip("Test not supported.")
def test_training(self):
pass
@unittest.skip("Test not supported.")
def test_layerwise_casting_training(self):
pass
def test_determinism(self):
super().test_determinism()
def test_outputs_equivalence(self):
super().test_outputs_equivalence()
def test_from_save_pretrained(self):
super().test_from_save_pretrained()
def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant()
def test_model_from_pretrained(self):
super().test_model_from_pretrained()
def test_output(self):
super().test_output()
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
def get_init_dict(self):
return {
"block_out_channels": (8, 8, 16, 16),
"in_channels": 14,
"out_channels": 14,
@@ -97,18 +70,40 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
"act_fn": "swish",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_features = 14
seq_len = 16
return {
"sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device),
"timestep": torch.tensor([10] * batch_size).to(torch_device),
}
class TestUNet1D(UNet1DTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Not implemented yet for this UNet")
def test_forward_with_norm_groups(self):
pass
class TestUNet1DMemory(UNet1DTesterConfig, MemoryTesterMixin):
@pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
def test_layerwise_casting_memory(self):
super().test_layerwise_casting_memory()
class TestUNet1DHubLoading(UNet1DTesterConfig):
def test_from_pretrained_hub(self):
model, loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
assert model is not None
assert len(loading_info["missing_keys"]) == 0
model.to(torch_device)
image = model(**self.dummy_input)
image = model(**self.get_dummy_inputs())
assert image is not None, "Make sure output is not None"
@@ -131,12 +126,7 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
# fmt: off
expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet
pass
assert torch.allclose(output_slice, expected_output_slice, rtol=1e-3)
@slow
def test_unet_1d_maestro(self):
@@ -157,98 +147,29 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
assert (output_sum - 224.0896).abs() < 0.5
assert (output_max - 0.0607).abs() < 4e-4
@pytest.mark.xfail(
reason=(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
),
)
def test_layerwise_casting_inference(self):
super().test_layerwise_casting_inference()
@pytest.mark.xfail(
reason=(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
),
)
def test_layerwise_casting_memory(self):
pass
# =============================================================================
# UNet1D RL (Value Function) Model Tests
# =============================================================================
class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet1DModel
main_input_name = "sample"
class UNet1DRLTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet1DModel testing (RL value function variant)."""
@property
def dummy_input(self):
batch_size = 4
num_features = 14
seq_len = 16
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
time_step = torch.tensor([10] * batch_size).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (4, 14, 16)
def model_class(self):
return UNet1DModel
@property
def output_shape(self):
return (4, 14, 1)
return (1,)
def test_determinism(self):
super().test_determinism()
@property
def main_input_name(self):
return "sample"
def test_outputs_equivalence(self):
super().test_outputs_equivalence()
def test_from_save_pretrained(self):
super().test_from_save_pretrained()
def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant()
def test_model_from_pretrained(self):
super().test_model_from_pretrained()
def test_output(self):
# UNetRL is a value-function is different output shape
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
@unittest.skip("Test not supported.")
def test_ema_training(self):
pass
@unittest.skip("Test not supported.")
def test_training(self):
pass
@unittest.skip("Test not supported.")
def test_layerwise_casting_training(self):
pass
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
def get_init_dict(self):
return {
"in_channels": 14,
"out_channels": 14,
"down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],
@@ -264,18 +185,54 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"time_embedding_type": "positional",
"act_fn": "mish",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_features = 14
seq_len = 16
return {
"sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device),
"timestep": torch.tensor([10] * batch_size).to(torch_device),
}
class TestUNet1DRL(UNet1DRLTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Not implemented yet for this UNet")
def test_forward_with_norm_groups(self):
pass
@torch.no_grad()
def test_output(self):
# UNetRL is a value-function with different output shape (batch, 1)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
inputs_dict = self.get_dummy_inputs()
output = model(**inputs_dict, return_dict=False)[0]
assert output is not None
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
assert output.shape == expected_shape, "Input and output shapes do not match"
class TestUNet1DRLMemory(UNet1DRLTesterConfig, MemoryTesterMixin):
@pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
def test_layerwise_casting_memory(self):
super().test_layerwise_casting_memory()
class TestUNet1DRLHubLoading(UNet1DRLTesterConfig):
def test_from_pretrained_hub(self):
value_function, vf_loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
)
self.assertIsNotNone(value_function)
self.assertEqual(len(vf_loading_info["missing_keys"]), 0)
assert value_function is not None
assert len(vf_loading_info["missing_keys"]) == 0
value_function.to(torch_device)
image = value_function(**self.dummy_input)
image = value_function(**self.get_dummy_inputs())
assert image is not None, "Make sure output is not None"
@@ -299,31 +256,4 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
# fmt: off
expected_output_slice = torch.tensor([165.25] * seq_len)
# fmt: on
self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet
pass
@pytest.mark.xfail(
reason=(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
),
)
def test_layerwise_casting_inference(self):
pass
@pytest.mark.xfail(
reason=(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
),
)
def test_layerwise_casting_memory(self):
pass
assert torch.allclose(output, expected_output_slice, rtol=1e-3)

View File

@@ -15,12 +15,11 @@
import gc
import math
import unittest
import pytest
import torch
from diffusers import UNet2DModel
from diffusers.utils import logging
from ...testing_utils import (
backend_empty_cache,
@@ -31,39 +30,40 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TrainingTesterMixin,
)
logger = logging.get_logger(__name__)
enable_full_determinism()
class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
main_input_name = "sample"
# =============================================================================
# Standard UNet2D Model Tests
# =============================================================================
class UNet2DTesterConfig(BaseModelTesterConfig):
"""Base configuration for standard UNet2DModel testing."""
@property
def dummy_input(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (3, 32, 32)
def model_class(self):
return UNet2DModel
@property
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
"block_out_channels": (4, 8),
"norm_num_groups": 2,
"down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
@@ -74,11 +74,22 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"layers_per_block": 2,
"sample_size": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
}
class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
def test_mid_block_attn_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["add_attention"] = True
init_dict["attn_norm_num_groups"] = 4
@@ -87,13 +98,11 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model.to(torch_device)
model.eval()
self.assertIsNotNone(
model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not."
assert model.mid_block.attentions[0].group_norm is not None, (
"Mid block Attention group norm should exist but does not."
)
self.assertEqual(
model.mid_block.attentions[0].group_norm.num_groups,
init_dict["attn_norm_num_groups"],
"Mid block Attention group norm does not have the expected number of groups.",
assert model.mid_block.attentions[0].group_norm.num_groups == init_dict["attn_norm_num_groups"], (
"Mid block Attention group norm does not have the expected number of groups."
)
with torch.no_grad():
@@ -102,13 +111,15 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
if isinstance(output, dict):
output = output.to_tuple()[0]
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_mid_block_none(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
mid_none_init_dict = self.get_init_dict()
mid_none_inputs_dict = self.get_dummy_inputs()
mid_none_init_dict["mid_block_type"] = None
model = self.model_class(**init_dict)
@@ -119,7 +130,7 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
mid_none_model.to(torch_device)
mid_none_model.eval()
self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")
assert mid_none_model.mid_block is None, "Mid block should not exist."
with torch.no_grad():
output = model(**inputs_dict)
@@ -133,8 +144,10 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
if isinstance(mid_none_output, dict):
mid_none_output = mid_none_output.to_tuple()[0]
self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")
assert not torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different."
class TestUNet2DTraining(UNet2DTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"AttnUpBlock2D",
@@ -143,41 +156,32 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"UpBlock2D",
"DownBlock2D",
}
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
attention_head_dim = 8
block_out_channels = (16, 32)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
)
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
main_input_name = "sample"
# =============================================================================
# UNet2D LDM Model Tests
# =============================================================================
class UNet2DLDMTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet2DModel LDM variant testing."""
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (4, 32, 32)
def model_class(self):
return UNet2DModel
@property
def output_shape(self):
return (4, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
"sample_size": 32,
"in_channels": 4,
"out_channels": 4,
@@ -187,17 +191,34 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"down_block_types": ("DownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "UpBlock2D"),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_channels = 4
sizes = (32, 32)
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
}
class TestUNet2DLDMTraining(UNet2DLDMTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig):
def test_from_pretrained_hub(self):
model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
assert model is not None
assert len(loading_info["missing_keys"]) == 0
model.to(torch_device)
image = model(**self.dummy_input).sample
image = model(**self.get_dummy_inputs()).sample
assert image is not None, "Make sure output is not None"
@@ -205,7 +226,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
def test_from_pretrained_accelerate(self):
model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model.to(torch_device)
image = model(**self.dummy_input).sample
image = model(**self.get_dummy_inputs()).sample
assert image is not None, "Make sure output is not None"
@@ -265,44 +286,31 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
# fmt: on
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
attention_head_dim = 32
block_out_channels = (32, 64)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
)
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-3)
class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
main_input_name = "sample"
# =============================================================================
# NCSN++ Model Tests
# =============================================================================
class NCSNppTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet2DModel NCSN++ variant testing."""
@property
def dummy_input(self, sizes=(32, 32)):
batch_size = 4
num_channels = 3
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (3, 32, 32)
def model_class(self):
return UNet2DModel
@property
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
"block_out_channels": [32, 64, 64, 64],
"in_channels": 3,
"layers_per_block": 1,
@@ -324,17 +332,71 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"SkipUpBlock2D",
],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device),
}
class TestNCSNpp(NCSNppTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Test not supported.")
def test_forward_with_norm_groups(self):
pass
@pytest.mark.skip(
"To make layerwise casting work with this model, we will have to update the implementation. "
"Due to potentially low usage, we don't support it here."
)
def test_keep_in_fp32_modules(self):
pass
@pytest.mark.skip(
"To make layerwise casting work with this model, we will have to update the implementation. "
"Due to potentially low usage, we don't support it here."
)
def test_from_save_pretrained_dtype_inference(self):
pass
class TestNCSNppMemory(NCSNppTesterConfig, MemoryTesterMixin):
@pytest.mark.skip(
"To make layerwise casting work with this model, we will have to update the implementation. "
"Due to potentially low usage, we don't support it here."
)
def test_layerwise_casting_memory(self):
pass
@pytest.mark.skip(
"To make layerwise casting work with this model, we will have to update the implementation. "
"Due to potentially low usage, we don't support it here."
)
def test_layerwise_casting_training(self):
pass
class TestNCSNppTraining(NCSNppTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"UNetMidBlock2D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestNCSNppHubLoading(NCSNppTesterConfig):
@slow
def test_from_pretrained_hub(self):
model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
assert model is not None
assert len(loading_info["missing_keys"]) == 0
model.to(torch_device)
inputs = self.dummy_input
inputs = self.get_dummy_inputs()
noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
inputs["sample"] = noise
image = model(**inputs)
@@ -361,7 +423,7 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056])
# fmt: on
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
def test_output_pretrained_ve_large(self):
model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
@@ -382,35 +444,4 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
# fmt: on
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# not required for this model
pass
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"UNetMidBlock2D",
}
block_out_channels = (32, 64, 64, 64)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, block_out_channels=block_out_channels
)
def test_effective_gradient_checkpointing(self):
super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})
@unittest.skip(
"To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
)
def test_layerwise_casting_inference(self):
pass
@unittest.skip(
"To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
)
def test_layerwise_casting_memory(self):
pass
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)

View File

@@ -20,6 +20,7 @@ import tempfile
import unittest
from collections import OrderedDict
import pytest
import torch
from huggingface_hub import snapshot_download
from parameterized import parameterized
@@ -52,17 +53,24 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import (
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
IPAdapterTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchCompileTesterMixin,
UNetTesterMixin,
TrainingTesterMixin,
)
if is_peft_available():
from peft import LoraConfig
from peft.tuners.tuners_utils import BaseTunerLayer
from ..testing_utils.lora import check_if_lora_correctly_set
logger = logging.get_logger(__name__)
@@ -82,16 +90,6 @@ def get_unet_lora_config():
return unet_lora_config
def check_if_lora_correctly_set(model) -> bool:
"""
Checks if the LoRA layers are correctly set with peft
"""
for module in model.modules():
if isinstance(module, BaseTunerLayer):
return True
return False
def create_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
ip_cross_attn_state_dict = {}
@@ -354,34 +352,28 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
return custom_diffusion_attn_procs
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
main_input_name = "sample"
# We override the items here because the unet under consideration is small.
model_split_percents = [0.5, 0.34, 0.4]
class UNet2DConditionTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet2DConditionModel testing."""
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (16, 16)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
def model_class(self):
return UNet2DConditionModel
@property
def input_shape(self):
def output_shape(self) -> tuple[int, int, int]:
return (4, 16, 16)
@property
def output_shape(self):
return (4, 16, 16)
def model_split_percents(self) -> list[float]:
return [0.5, 0.34, 0.4]
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
@property
def main_input_name(self) -> str:
return "sample"
def get_init_dict(self) -> dict:
"""Return UNet2D model initialization arguments."""
return {
"block_out_channels": (4, 8),
"norm_num_groups": 4,
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
@@ -393,26 +385,24 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
"layers_per_block": 1,
"sample_size": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
"""Return dummy inputs for UNet2D model."""
batch_size = 4
num_channels = 4
sizes = (16, 16)
model.enable_xformers_memory_efficient_attention()
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
}
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin):
def test_model_with_attention_head_dim_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -427,12 +417,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_model_with_use_linear_projection(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["use_linear_projection"] = True
@@ -446,12 +437,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_model_with_cross_attention_dim_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["cross_attention_dim"] = (8, 8)
@@ -465,12 +457,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_model_with_simple_projection(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
batch_size, _, _, sample_size = inputs_dict["sample"].shape
@@ -489,12 +482,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_model_with_class_embeddings_concat(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
batch_size, _, _, sample_size = inputs_dict["sample"].shape
@@ -514,12 +508,287 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
# since the use-case (somebody passes in a too-short cross-attn mask) is pretty small,
# maybe it's fine that this only works for the unclip use-case.
@mark.skip(
reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
)
def test_model_xattn_padding(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
model.to(torch_device)
model.eval()
cond = inputs_dict["encoder_hidden_states"]
with torch.no_grad():
full_cond_out = model(**inputs_dict).sample
assert full_cond_out is not None
batch, tokens, _ = cond.shape
keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
assert trunc_mask_out.allclose(keeplast_out), (
"a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
)
def test_pickle(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
sample = model(**inputs_dict).sample
sample_copy = copy.copy(sample)
assert (sample - sample_copy).abs().max() < 1e-4
def test_asymmetrical_unet(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
# Add asymmetry to configs
init_dict["transformer_layers_per_block"] = [[3, 2], 1]
init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
output = model(**inputs_dict).sample
expected_shape = inputs_dict["sample"].shape
# Check if input and output shapes are the same
assert output.shape == expected_shape, "Input and output shapes do not match"
class TestUNet2DConditionHubLoading(UNet2DConditionTesterConfig):
"""Hub checkpoint loading tests for UNet2DConditionModel."""
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
inputs_dict = self.get_dummy_inputs()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
inputs_dict = self.get_dummy_inputs()
loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_local(self):
inputs_dict = self.get_dummy_inputs()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
inputs_dict = self.get_dummy_inputs()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
inputs_dict = self.get_dummy_inputs()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
inputs_dict = self.get_dummy_inputs()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
inputs_dict = self.get_dummy_inputs()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
inputs_dict = self.get_dummy_inputs()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
loaded_model = self.model_class.from_pretrained(
ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
class TestUNet2DConditionLoRA(UNet2DConditionTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for UNet2DConditionModel."""
@require_peft_backend
def test_load_attn_procs_raise_warning(self):
"""Test that deprecated load_attn_procs method raises FutureWarning."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
# forward pass without LoRA
with torch.no_grad():
non_lora_sample = model(**inputs_dict).sample
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
# forward pass with LoRA
with torch.no_grad():
lora_sample_1 = model(**inputs_dict).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
model.unload_lora()
with pytest.warns(FutureWarning, match="Using the `load_attn_procs\\(\\)` method has been deprecated"):
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
# import to still check for the rest of the stuff.
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with torch.no_grad():
lora_sample_2 = model(**inputs_dict).sample
assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
"LoRA injected UNet should produce different results."
)
assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
"Loading from a saved checkpoint should produce identical results."
)
@require_peft_backend
def test_save_attn_procs_raise_warning(self):
"""Test that deprecated save_attn_procs method raises FutureWarning."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with tempfile.TemporaryDirectory() as tmpdirname:
with pytest.warns(FutureWarning, match="Using the `save_attn_procs\\(\\)` method has been deprecated"):
model.save_attn_procs(os.path.join(tmpdirname))
class TestUNet2DConditionMemory(UNet2DConditionTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for UNet2DConditionModel."""
class TestUNet2DConditionTraining(UNet2DConditionTesterConfig, TrainingTesterMixin):
"""Training tests for UNet2DConditionModel."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"CrossAttnUpBlock2D",
"CrossAttnDownBlock2D",
"UNetMidBlock2DCrossAttn",
"UpBlock2D",
"Transformer2DModel",
"DownBlock2D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterMixin):
"""Attention processor tests for UNet2DConditionModel."""
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
def test_model_attention_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -544,7 +813,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert output is not None
def test_model_sliceable_head_dim(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -562,21 +831,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
for module in model.children():
check_sliceable_dim_attr(module)
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"CrossAttnUpBlock2D",
"CrossAttnDownBlock2D",
"UNetMidBlock2DCrossAttn",
"UpBlock2D",
"Transformer2DModel",
"DownBlock2D",
}
attention_head_dim = (8, 16)
block_out_channels = (16, 32)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
)
def test_special_attn_proc(self):
class AttnEasyProc(torch.nn.Module):
def __init__(self, num):
@@ -618,7 +872,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
return hidden_states
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -645,7 +900,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
]
)
def test_model_xattn_mask(self, mask_dtype):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)})
model.to(torch_device)
@@ -675,39 +931,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
"masking the last token from our cond should be equivalent to truncating that token out of the condition"
)
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
# since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric.
# maybe it's fine that this only works for the unclip use-case.
@mark.skip(
reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
)
def test_model_xattn_padding(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
model.to(torch_device)
model.eval()
cond = inputs_dict["encoder_hidden_states"]
with torch.no_grad():
full_cond_out = model(**inputs_dict).sample
assert full_cond_out is not None
batch, tokens, _ = cond.shape
keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
assert trunc_mask_out.allclose(keeplast_out), (
"a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
)
class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
"""Custom Diffusion processor tests for UNet2DConditionModel."""
def test_custom_diffusion_processors(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -733,8 +963,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert (sample1 - sample2).abs().max() < 3e-3
def test_custom_diffusion_save_load(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -754,7 +984,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname, safe_serialization=False)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin")
@@ -773,8 +1003,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_custom_diffusion_xformers_on_off(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -798,41 +1028,28 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert (sample - on_sample).abs().max() < 1e-4
assert (sample - off_sample).abs().max() < 1e-4
def test_pickle(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterMixin):
"""IP Adapter tests for UNet2DConditionModel."""
model = self.model_class(**init_dict)
model.to(torch_device)
@property
def ip_adapter_processor_cls(self):
return (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)
with torch.no_grad():
sample = model(**inputs_dict).sample
def create_ip_adapter_state_dict(self, model):
return create_ip_adapter_state_dict(model)
sample_copy = copy.copy(sample)
assert (sample - sample_copy).abs().max() < 1e-4
def test_asymmetrical_unet(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
# Add asymmetry to configs
init_dict["transformer_layers_per_block"] = [[3, 2], 1]
init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
output = model(**inputs_dict).sample
expected_shape = inputs_dict["sample"].shape
# Check if input and output shapes are the same
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
batch_size = inputs_dict["encoder_hidden_states"].shape[0]
# for ip-adapter image_embeds has shape [batch_size, num_image, embed_dim]
cross_attention_dim = getattr(model.config, "cross_attention_dim", 8)
image_embeds = floats_tensor((batch_size, 1, cross_attention_dim)).to(torch_device)
inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]}
return inputs_dict
def test_ip_adapter(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -905,7 +1122,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
def test_ip_adapter_plus(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -977,185 +1195,16 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
class TestUNet2DConditionModelCompile(UNet2DConditionTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for UNet2DConditionModel."""
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_local(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
loaded_model = self.model_class.from_pretrained(
ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_peft_backend
def test_load_attn_procs_raise_warning(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
# forward pass without LoRA
with torch.no_grad():
non_lora_sample = model(**inputs_dict).sample
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
# forward pass with LoRA
with torch.no_grad():
lora_sample_1 = model(**inputs_dict).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
model.unload_lora()
with self.assertWarns(FutureWarning) as warning:
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
warning_message = str(warning.warnings[0].message)
assert "Using the `load_attn_procs()` method has been deprecated" in warning_message
# import to still check for the rest of the stuff.
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with torch.no_grad():
lora_sample_2 = model(**inputs_dict).sample
assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
"LoRA injected UNet should produce different results."
)
assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
"Loading from a saved checkpoint should produce identical results."
)
@require_peft_backend
def test_save_attn_procs_raise_warning(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with tempfile.TemporaryDirectory() as tmpdirname:
with self.assertWarns(FutureWarning) as warning:
model.save_attn_procs(tmpdirname)
warning_message = str(warning.warnings[0].message)
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
def test_torch_compile_repeated_blocks(self):
return super().test_torch_compile_repeated_blocks(recompile_limit=2)
class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
def prepare_init_args_and_inputs_for_common(self):
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
def prepare_init_args_and_inputs_for_common(self):
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
class TestUNet2DConditionModelLoRAHotSwap(UNet2DConditionTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for UNet2DConditionModel."""
@slow

View File

@@ -18,47 +18,44 @@ import unittest
import numpy as np
import torch
from diffusers.models import ModelMixin, UNet3DConditionModel
from diffusers.utils import logging
from diffusers import UNet3DConditionModel
from diffusers.utils.import_utils import is_xformers_available
from ...testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ...testing_utils import (
enable_full_determinism,
floats_tensor,
skip_mps,
torch_device,
)
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
ModelTesterMixin,
)
enable_full_determinism()
logger = logging.get_logger(__name__)
@skip_mps
class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet3DConditionModel
main_input_name = "sample"
class UNet3DConditionTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet3DConditionModel testing."""
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
num_frames = 4
sizes = (16, 16)
noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@property
def input_shape(self):
return (4, 4, 16, 16)
def model_class(self):
return UNet3DConditionModel
@property
def output_shape(self):
return (4, 4, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
"block_out_channels": (4, 8),
"norm_num_groups": 4,
"down_block_types": (
@@ -73,27 +70,25 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
"layers_per_block": 1,
"sample_size": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
def get_dummy_inputs(self):
batch_size = 4
num_channels = 4
num_frames = 4
sizes = (16, 16)
model.enable_xformers_memory_efficient_attention()
return {
"sample": floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
}
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
class TestUNet3DCondition(UNet3DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin):
# Overriding to set `norm_num_groups` needs to be different for this model.
def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (32, 64)
init_dict["norm_num_groups"] = 32
@@ -107,39 +102,74 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
# Overriding since the UNet3D outputs a different structure.
@torch.no_grad()
def test_determinism(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
model(**self.dummy_input)
inputs_dict = self.get_dummy_inputs()
first = model(**inputs_dict)
if isinstance(first, dict):
first = first.sample
first = model(**inputs_dict)
if isinstance(first, dict):
first = first.sample
second = model(**inputs_dict)
if isinstance(second, dict):
second = second.sample
second = model(**inputs_dict)
if isinstance(second, dict):
second = second.sample
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
assert max_diff <= 1e-5
def test_feed_forward_chunking(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (32, 64)
init_dict["norm_num_groups"] = 32
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)[0]
model.enable_forward_chunking()
with torch.no_grad():
output_2 = model(**inputs_dict)[0]
assert output.shape == output_2.shape, "Shape doesn't match"
assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2
class TestUNet3DConditionAttention(UNet3DConditionTesterConfig, AttentionTesterMixin):
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
def test_model_attention_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = 8
@@ -162,22 +192,3 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
def test_feed_forward_chunking(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (32, 64)
init_dict["norm_num_groups"] = 32
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)[0]
model.enable_forward_chunking()
with torch.no_grad():
output_2 = model(**inputs_dict)[0]
self.assertEqual(output.shape, output_2.shape, "Shape doesn't match")
assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2

View File

@@ -13,59 +13,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import pytest
import torch
from torch import nn
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
from diffusers.utils import logging
from ...testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TrainingTesterMixin,
)
logger = logging.get_logger(__name__)
enable_full_determinism()
class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNetControlNetXSModel
main_input_name = "sample"
class UNetControlNetXSTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNetControlNetXSModel testing."""
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (16, 16)
conditioning_image_size = (3, 32, 32) # size of additional, unprocessed image for control-conditioning
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device)
conditioning_scale = 1
return {
"sample": noise,
"timestep": time_step,
"encoder_hidden_states": encoder_hidden_states,
"controlnet_cond": controlnet_cond,
"conditioning_scale": conditioning_scale,
}
@property
def input_shape(self):
return (4, 16, 16)
def model_class(self):
return UNetControlNetXSModel
@property
def output_shape(self):
return (4, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
"sample_size": 16,
"down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
"up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
@@ -80,11 +63,23 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
"ctrl_max_norm_num_groups": 2,
"ctrl_conditioning_embedding_out_channels": (2, 2),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_channels = 4
sizes = (16, 16)
conditioning_image_size = (3, 32, 32)
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
"controlnet_cond": floats_tensor((batch_size, *conditioning_image_size)).to(torch_device),
"conditioning_scale": 1,
}
def get_dummy_unet(self):
"""For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
"""Build the underlying UNet for tests that construct UNetControlNetXSModel from UNet + Adapter."""
return UNet2DConditionModel(
block_out_channels=(4, 8),
layers_per_block=2,
@@ -99,10 +94,16 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
)
def get_dummy_controlnet_from_unet(self, unet, **kwargs):
"""For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
# size_ratio and conditioning_embedding_out_channels chosen to keep model small
"""Build the ControlNetXS-Adapter from a UNet."""
return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs)
class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# UNetControlNetXSModel only supports SD/SDXL with norm_num_groups=32
pass
def test_from_unet(self):
unet = self.get_dummy_unet()
controlnet = self.get_dummy_controlnet_from_unet(unet)
@@ -115,7 +116,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value)
# # check unet
# everything expect down,mid,up blocks
# everything except down,mid,up blocks
modules_from_unet = [
"time_embedding",
"conv_in",
@@ -152,7 +153,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers")
# # check controlnet
# everything expect down,mid,up blocks
# everything except down,mid,up blocks
modules_from_controlnet = {
"controlnet_cond_embedding": "controlnet_cond_embedding",
"conv_in": "ctrl_conv_in",
@@ -193,12 +194,12 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
for p in module.parameters():
assert p.requires_grad
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
model = UNetControlNetXSModel(**init_dict)
model.freeze_unet_params()
# # check unet
# everything expect down,mid,up blocks
# everything except down,mid,up blocks
modules_from_unet = [
model.base_time_embedding,
model.base_conv_in,
@@ -236,7 +237,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
assert_frozen(u.upsamplers)
# # check controlnet
# everything expect down,mid,up blocks
# everything except down,mid,up blocks
modules_from_controlnet = [
model.controlnet_cond_embedding,
model.ctrl_conv_in,
@@ -267,16 +268,6 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
for u in model.up_blocks:
assert_unfrozen(u.ctrl_to_base)
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"Transformer2DModel",
"UNetMidBlock2DCrossAttn",
"ControlNetXSCrossAttnDownBlock2D",
"ControlNetXSCrossAttnMidBlock2D",
"ControlNetXSCrossAttnUpBlock2D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@is_flaky
def test_forward_no_control(self):
unet = self.get_dummy_unet()
@@ -287,7 +278,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
unet = unet.to(torch_device)
model = model.to(torch_device)
input_ = self.dummy_input
input_ = self.get_dummy_inputs()
control_specific_input = ["controlnet_cond", "conditioning_scale"]
input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input}
@@ -312,7 +303,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
model = model.to(torch_device)
model_mix_time = model_mix_time.to(torch_device)
input_ = self.dummy_input
input_ = self.get_dummy_inputs()
with torch.no_grad():
output = model(**input_).sample
@@ -320,7 +311,14 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Tes
assert output.shape == output_mix_time.shape
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups.
pass
class TestUNetControlNetXSTraining(UNetControlNetXSTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"Transformer2DModel",
"UNetMidBlock2DCrossAttn",
"ControlNetXSCrossAttnDownBlock2D",
"ControlNetXSCrossAttnMidBlock2D",
"ControlNetXSCrossAttnUpBlock2D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -16,10 +16,10 @@
import copy
import unittest
import pytest
import torch
from diffusers import UNetSpatioTemporalConditionModel
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from ...testing_utils import (
@@ -28,45 +28,34 @@ from ...testing_utils import (
skip_mps,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
ModelTesterMixin,
TrainingTesterMixin,
)
logger = logging.get_logger(__name__)
enable_full_determinism()
@skip_mps
class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNetSpatioTemporalConditionModel
main_input_name = "sample"
class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNetSpatioTemporalConditionModel testing."""
@property
def dummy_input(self):
batch_size = 2
num_frames = 2
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
return {
"sample": noise,
"timestep": time_step,
"encoder_hidden_states": encoder_hidden_states,
"added_time_ids": self._get_add_time_ids(),
}
@property
def input_shape(self):
return (2, 2, 4, 32, 32)
def model_class(self):
return UNetSpatioTemporalConditionModel
@property
def output_shape(self):
return (4, 32, 32)
@property
def main_input_name(self):
return "sample"
@property
def fps(self):
return 6
@@ -83,8 +72,8 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
def addition_time_embed_dim(self):
return 32
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
def get_init_dict(self):
return {
"block_out_channels": (32, 64),
"down_block_types": (
"CrossAttnDownBlockSpatioTemporal",
@@ -103,8 +92,23 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
"projection_class_embeddings_input_dim": self.addition_time_embed_dim * 3,
"addition_time_embed_dim": self.addition_time_embed_dim,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 2
num_frames = 2
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
return {
"sample": noise,
"timestep": time_step,
"encoder_hidden_states": encoder_hidden_states,
"added_time_ids": self._get_add_time_ids(),
}
def _get_add_time_ids(self, do_classifier_free_guidance=True):
add_time_ids = [self.fps, self.motion_bucket_id, self.noise_aug_strength]
@@ -124,43 +128,15 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
return add_time_ids
@unittest.skip("Number of Norm Groups is not configurable")
class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Number of Norm Groups is not configurable")
def test_forward_with_norm_groups(self):
pass
@unittest.skip("Deprecated functionality")
def test_model_attention_slicing(self):
pass
@unittest.skip("Not supported")
def test_model_with_use_linear_projection(self):
pass
@unittest.skip("Not supported")
def test_model_with_simple_projection(self):
pass
@unittest.skip("Not supported")
def test_model_with_class_embeddings_concat(self):
pass
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
def test_model_with_num_attention_heads_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["num_attention_heads"] = (8, 16)
model = self.model_class(**init_dict)
@@ -173,12 +149,13 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_model_with_cross_attention_dim_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["cross_attention_dim"] = (32, 32)
@@ -192,27 +169,13 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"TransformerSpatioTemporalModel",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"UNetMidBlockSpatioTemporal",
}
num_attention_heads = (8, 16)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, num_attention_heads=num_attention_heads
)
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_pickle(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["num_attention_heads"] = (8, 16)
@@ -225,3 +188,33 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, u
sample_copy = copy.copy(sample)
assert (sample - sample_copy).abs().max() < 1e-4
class TestUNetSpatioTemporalAttention(UNetSpatioTemporalTesterConfig, AttentionTesterMixin):
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
class TestUNetSpatioTemporalTraining(UNetSpatioTemporalTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"TransformerSpatioTemporalModel",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"UNetMidBlockSpatioTemporal",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

View File

@@ -0,0 +1,72 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from diffusers.modular_pipelines import LTXAutoBlocks, LTXModularPipeline
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
LTX_WORKFLOWS = {
"text2video": [
("text_encoder", "LTXTextEncoderStep"),
("denoise.input", "LTXTextInputStep"),
("denoise.set_timesteps", "LTXSetTimestepsStep"),
("denoise.prepare_latents", "LTXPrepareLatentsStep"),
("denoise.denoise", "LTXDenoiseStep"),
("decode", "LTXVaeDecoderStep"),
],
"image2video": [
("text_encoder", "LTXTextEncoderStep"),
("vae_encoder", "LTXVaeEncoderStep"),
("denoise.input", "LTXTextInputStep"),
("denoise.set_timesteps", "LTXSetTimestepsStep"),
("denoise.prepare_latents", "LTXPrepareLatentsStep"),
("denoise.prepare_i2v_latents", "LTXImage2VideoPrepareLatentsStep"),
("denoise.denoise", "LTXImage2VideoDenoiseStep"),
("decode", "LTXVaeDecoderStep"),
],
}
class TestLTXModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = LTXModularPipeline
pipeline_blocks_class = LTXAutoBlocks
pretrained_model_name_or_path = "akshan-main/tiny-ltx-modular-pipe"
params = frozenset(["prompt", "height", "width", "num_frames"])
batch_params = frozenset(["prompt"])
optional_params = frozenset(["num_inference_steps", "num_videos_per_prompt", "latents"])
expected_workflow_blocks = LTX_WORKFLOWS
output_name = "videos"
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"num_frames": 9,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
@pytest.mark.skip(reason="num_videos_per_prompt")
def test_num_images_per_prompt(self):
pass

View File

@@ -13,16 +13,14 @@
# limitations under the License.
import inspect
import unittest
from importlib import import_module
import pytest
class DependencyTester(unittest.TestCase):
class TestDependencies:
def test_diffusers_import(self):
try:
import diffusers # noqa: F401
except ImportError:
assert False
import diffusers # noqa: F401
def test_backend_registration(self):
import diffusers
@@ -52,3 +50,36 @@ class DependencyTester(unittest.TestCase):
if hasattr(diffusers.pipelines, cls_name):
pipeline_folder_module = ".".join(str(cls_module.__module__).split(".")[:3])
_ = import_module(pipeline_folder_module, str(cls_name))
def test_pipeline_module_imports(self):
"""Import every pipeline submodule whose dependencies are satisfied,
to catch unguarded optional-dep imports (e.g., torchvision).
Uses inspect.getmembers to discover classes that the lazy loader can
actually resolve (same self-filtering as test_pipeline_imports), then
imports the full module path instead of truncating to the folder level.
"""
import diffusers
import diffusers.pipelines
failures = []
all_classes = inspect.getmembers(diffusers, inspect.isclass)
for cls_name, cls_module in all_classes:
if not hasattr(diffusers.pipelines, cls_name):
continue
if "dummy_" in cls_module.__module__:
continue
full_module_path = cls_module.__module__
try:
import_module(full_module_path)
except ImportError as e:
failures.append(f"{full_module_path}: {e}")
except Exception:
# Non-import errors (e.g., missing config) are fine; we only
# care about unguarded import statements.
pass
if failures:
pytest.fail("Unguarded optional-dependency imports found:\n" + "\n".join(failures))