Compare commits

..

7 Commits

Author SHA1 Message Date
yiyi@huggingface.co
7f97055d5f fix test 2026-02-02 07:34:45 +00:00
yiyi@huggingface.co
85d8d244a1 upup more 2026-02-02 07:21:07 +00:00
yiyi@huggingface.co
7d272bed80 fix copies 2026-02-02 07:15:27 +00:00
yiyi@huggingface.co
e95c5a2609 copies 2026-02-02 06:01:14 +00:00
yiyi@huggingface.co
80275cf18b style 2026-02-02 06:00:38 +00:00
yiyi@huggingface.co
2b74061a11 fix init_pipeline etc 2026-02-02 05:59:05 +00:00
yiyi@huggingface.co
23fb285912 initil 2026-02-02 02:39:04 +00:00
39 changed files with 983 additions and 7125 deletions

View File

@@ -343,34 +343,6 @@ We ran a benchmark with Ulysess, Ring, and Unified Attention with [this script](
From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by unified attention.
### Ulysses Anything Attention
The default Ulysses Attention mechanism requires that the sequence length of hidden states must be divisible by the number of devices. This imposes significant limitations on the practical application of Ulysses Attention. [Ulysses Anything Attention](https://github.com/huggingface/diffusers/pull/12996) is a variant of Ulysses Attention that supports arbitrary sequence lengths and arbitrary numbers of attention heads, thereby enhancing the versatility of Ulysses Attention in practical use.
[`ContextParallelConfig`] supports Ulysses Anything Attention by specifying both `ulysses_degree` and `ulysses_anything`. Please note that Ulysses Anything Attention is not currently supported by Unified Attention. Pass the [`ContextParallelConfig`] with both `ulysses_degree` set to bigger than 1 and `ulysses_anything=True` to [`~ModelMixin.enable_parallelism`].
```py
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ulysses_anything=True))
```
> [!TIP] To avoid multiple forced CUDA sync caused by H2D and D2H transfers, please add the **gloo** backend in `init_process_group`. This will significantly reduce communication latency.
We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention and Ulysses Anything Attention with [this script](https://github.com/huggingface/diffusers/pull/12996#issuecomment-3797695999) on a node of 4 L20 GPUs. The results are summarized as follows:
| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW)|
|--------------------|------------------|-------------|------------------|------------|
| ulysses | 281.07 | 3.56 | 37.11 | 1024x1024 |
| ring | 351.34 | 2.85 | 37.01 | 1024x1024 |
| unified_balanced | 324.37 | 3.08 | 37.16 | 1024x1024 |
| ulysses_anything | 280.94 | 3.56 | 37.11 | 1024x1024 |
| ulysses | failed | failed | failed | 1008x1008 |
| ring | failed | failed | failed | 1008x1008 |
| unified_balanced | failed | failed | failed | 1008x1008 |
| ulysses_anything | 278.40 | 3.59 | 36.99 | 1008x1008 |
From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention.
### parallel_config
Pass `parallel_config` during model initialization to enable context parallelism.

View File

@@ -415,6 +415,7 @@ else:
"Flux2AutoBlocks",
"Flux2KleinAutoBlocks",
"Flux2KleinBaseAutoBlocks",
"Flux2KleinBaseModularPipeline",
"Flux2KleinModularPipeline",
"Flux2ModularPipeline",
"FluxAutoBlocks",
@@ -431,8 +432,13 @@ else:
"QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
"Wan22AutoBlocks",
"WanAutoBlocks",
"Wan22Blocks",
"Wan22Image2VideoBlocks",
"Wan22Image2VideoModularPipeline",
"Wan22ModularPipeline",
"WanBlocks",
"WanImage2VideoAutoBlocks",
"WanImage2VideoModularPipeline",
"WanModularPipeline",
"ZImageAutoBlocks",
"ZImageModularPipeline",
@@ -1151,6 +1157,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinBaseModularPipeline,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
FluxAutoBlocks,
@@ -1167,8 +1174,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
Wan22AutoBlocks,
WanAutoBlocks,
Wan22Blocks,
Wan22Image2VideoBlocks,
Wan22Image2VideoModularPipeline,
Wan22ModularPipeline,
WanBlocks,
WanImage2VideoAutoBlocks,
WanImage2VideoModularPipeline,
WanModularPipeline,
ZImageAutoBlocks,
ZImageModularPipeline,

View File

@@ -11,14 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import inspect
from dataclasses import dataclass
from typing import Dict, List, Tuple, Type, Union
from typing import Dict, List, Type, Union
import torch
import torch.distributed as dist
if torch.distributed.is_available():
@@ -29,10 +27,9 @@ from ..models._modeling_parallel import (
ContextParallelInput,
ContextParallelModelPlan,
ContextParallelOutput,
gather_size_by_comm,
)
from ..utils import get_logger
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
from ..utils.torch_utils import unwrap_module
from .hooks import HookRegistry, ModelHook
@@ -211,10 +208,6 @@ class ContextParallelSplitHook(ModelHook):
)
return x
else:
if self.parallel_config.ulysses_anything:
return PartitionAnythingSharder.shard_anything(
x, cp_input.split_dim, self.parallel_config._flattened_mesh
)
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
@@ -240,14 +233,7 @@ class ContextParallelGatherHook(ModelHook):
for i, cpm in enumerate(self.metadata):
if cpm is None:
continue
if self.parallel_config.ulysses_anything:
output[i] = PartitionAnythingSharder.unshard_anything(
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
)
else:
output[i] = EquipartitionSharder.unshard(
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
)
output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
return output[0] if is_tensor else tuple(output)
@@ -288,73 +274,6 @@ class EquipartitionSharder:
return tensor
class AllGatherAnythingFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh):
ctx.dim = dim
ctx.group = group
ctx.world_size = dist.get_world_size(group)
ctx.rank = dist.get_rank(group)
gathered_tensor = _all_gather_anything(tensor, dim, group)
return gathered_tensor
@staticmethod
def backward(ctx, grad_output):
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
# function may return fewer than the specified number of chunks!
grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim)
return grad_splits[ctx.rank], None, None
class PartitionAnythingSharder:
@classmethod
def shard_anything(
cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
) -> torch.Tensor:
assert tensor.size()[dim] >= mesh.size(), (
f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}."
)
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
# function may return fewer than the specified number of chunks!
return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())]
@classmethod
def unshard_anything(
cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
) -> torch.Tensor:
tensor = tensor.contiguous()
tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group())
return tensor
@functools.lru_cache(maxsize=64)
def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, world_size: int) -> List[List[int]]:
gather_shapes = []
for i in range(world_size):
rank_shape = list(copy.deepcopy(shape))
rank_shape[dim] = gather_dims[i]
gather_shapes.append(rank_shape)
return gather_shapes
@maybe_allow_in_graph
def _all_gather_anything(tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh) -> torch.Tensor:
world_size = dist.get_world_size(group=group)
tensor = tensor.contiguous()
shape = tensor.shape
rank_dim = shape[dim]
gather_dims = gather_size_by_comm(rank_dim, group)
gather_shapes = _fill_gather_shapes(tuple(shape), tuple(gather_dims), dim, world_size)
gathered_tensors = [torch.empty(shape, device=tensor.device, dtype=tensor.dtype) for shape in gather_shapes]
dist.all_gather(gathered_tensors, tensor, group=group)
gathered_tensor = torch.cat(gathered_tensors, dim=dim)
return gathered_tensor
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
if name.count("*") > 1:
raise ValueError("Wildcard '*' can only be used once in the name")

View File

@@ -2321,14 +2321,8 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
prefix = "diffusion_model."
original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
num_double_layers = 0
num_single_layers = 0
for key in original_state_dict.keys():
if key.startswith("single_blocks."):
num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1)
elif key.startswith("double_blocks."):
num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1)
num_double_layers = 8
num_single_layers = 48
lora_keys = ("lora_A", "lora_B")
attn_types = ("img_attn", "txt_attn")

View File

@@ -19,7 +19,6 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
import torch
import torch.distributed as dist
from ..utils import get_logger
@@ -68,9 +67,6 @@ class ContextParallelConfig:
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
# Whether to enable ulysses anything attention to support
# any sequence lengths and any head numbers.
ulysses_anything: bool = False
_rank: int = None
_world_size: int = None
@@ -98,11 +94,6 @@ class ContextParallelConfig:
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
)
if self.ulysses_anything:
if self.ulysses_degree == 1:
raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.")
if self.ring_degree > 1:
raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.")
@property
def mesh_shape(self) -> Tuple[int, int]:
@@ -266,39 +257,3 @@ ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextPara
#
# ContextParallelOutput:
# specifies how to gather the input tensor in the post-forward hook in the layer it is attached to
# Below are utility functions for distributed communication in context parallelism.
def gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]:
r"""Gather the local size from all ranks.
size: int, local size return: List[int], list of size from all ranks
"""
# NOTE(Serving/CP Safety):
# Do NOT cache this collective result.
#
# In "Ulysses Anything" mode, `size` (e.g. per-rank local seq_len / S_LOCAL)
# may legitimately differ across ranks. If we cache based on the *local* `size`,
# different ranks can have different cache hit/miss patterns across time.
#
# That can lead to a catastrophic distributed hang:
# - some ranks hit cache and *skip* dist.all_gather()
# - other ranks miss cache and *enter* dist.all_gather()
# This mismatched collective participation will stall the process group and
# eventually trigger NCCL watchdog timeouts (often surfacing later as ALLTOALL
# timeouts in Ulysses attention).
world_size = dist.get_world_size(group=group)
# HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead
comm_backends = str(dist.get_backend(group=group))
# NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl")
gather_device = "cpu" if "cpu" in comm_backends else torch.accelerator.current_accelerator()
gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)]
dist.all_gather(
gathered_sizes,
torch.tensor([size], device=gather_device, dtype=torch.int64),
group=group,
)
gathered_sizes = [s[0].item() for s in gathered_sizes]
# NOTE: DON'T use tolist here due to graph break - Explanation:
# Backend compiler `inductor` failed with aten._local_scalar_dense.default
return gathered_sizes

View File

@@ -21,8 +21,6 @@ from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn.functional as F
if torch.distributed.is_available():
@@ -46,8 +44,6 @@ from ..utils import (
is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
from ..utils.torch_utils import maybe_allow_in_graph
from ._modeling_parallel import gather_size_by_comm
if TYPE_CHECKING:
@@ -1283,154 +1279,6 @@ class SeqAllToAllDim(torch.autograd.Function):
return (None, grad_input, None, None)
# Below are helper functions to handle abritrary head num and abritrary sequence length for Ulysses Anything Attention.
def _maybe_pad_qkv_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]:
r"""Maybe pad the head dimension to be divisible by world_size.
x: torch.Tensor, shape (B, S_LOCAL, H, D) H: int, original global head num return: Tuple[torch.Tensor, int], padded
tensor (B, S_LOCAL, H + H_PAD, D) and H_PAD
"""
world_size = dist.get_world_size(group=group)
H_PAD = 0
if H % world_size != 0:
H_PAD = world_size - (H % world_size)
NEW_H_LOCAL = (H + H_PAD) // world_size
# e.g., Allow: H=30, world_size=8 -> NEW_H_LOCAL=4, H_PAD=2.
# NOT ALLOW: H=30, world_size=16 -> NEW_H_LOCAL=2, H_PAD=14.
assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
return x, H_PAD
def _maybe_unpad_qkv_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor:
r"""Maybe unpad the head dimension.
x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor,
unpadded tensor (B, S_GLOBAL, H_LOCAL, D)
"""
rank = dist.get_rank(group=group)
world_size = dist.get_world_size(group=group)
# Only the last rank may have padding
if H_PAD > 0 and rank == world_size - 1:
x = x[:, :, :-H_PAD, :]
return x.contiguous()
def _maybe_pad_o_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]:
r"""Maybe pad the head dimension to be divisible by world_size.
x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) H: int, original global head num return: Tuple[torch.Tensor, int],
padded tensor (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD
"""
if H is None:
return x, 0
rank = dist.get_rank(group=group)
world_size = dist.get_world_size(group=group)
H_PAD = 0
# Only the last rank may need padding
if H % world_size != 0:
# We need to broadcast H_PAD to all ranks to keep consistency
# in unpadding step later for all ranks.
H_PAD = world_size - (H % world_size)
NEW_H_LOCAL = (H + H_PAD) // world_size
assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
if rank == world_size - 1:
x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
return x, H_PAD
def _maybe_unpad_o_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor:
r"""Maybe unpad the head dimension.
x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor,
unpadded tensor (B, S_LOCAL, H_GLOBAL, D)
"""
if H_PAD > 0:
x = x[:, :, :-H_PAD, :]
return x.contiguous()
def ulysses_anything_metadata(query: torch.Tensor, **kwargs) -> dict:
# query: (B, S_LOCAL, H_GLOBAL, D)
assert len(query.shape) == 4, "Query tensor must be 4-dimensional of shape (B, S_LOCAL, H_GLOBAL, D)"
extra_kwargs = {}
extra_kwargs["NUM_QO_HEAD"] = query.shape[2]
extra_kwargs["Q_S_LOCAL"] = query.shape[1]
# Add other kwargs if needed in future
return extra_kwargs
@maybe_allow_in_graph
def all_to_all_single_any_qkv_async(
x: torch.Tensor, group: dist.ProcessGroup, **kwargs
) -> Callable[..., torch.Tensor]:
r"""
x: torch.Tensor, shape (B, S_LOCAL, H, D) return: Callable that returns (B, S_GLOBAL, H_LOCAL, D)
"""
world_size = dist.get_world_size(group=group)
B, S_LOCAL, H, D = x.shape
x, H_PAD = _maybe_pad_qkv_head(x, H, group)
H_LOCAL = (H + H_PAD) // world_size
# (world_size, S_LOCAL, B, H_LOCAL, D)
x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
input_split_sizes = [S_LOCAL] * world_size
# S_LOCAL maybe not equal for all ranks in dynamic shape case,
# since we don't know the actual shape before this timing, thus,
# we have to use all gather to collect the S_LOCAL first.
output_split_sizes = gather_size_by_comm(S_LOCAL, group)
x = x.flatten(0, 1) # (world_size * S_LOCAL, B, H_LOCAL, D)
x = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group)
def wait() -> torch.Tensor:
nonlocal x, H_PAD
x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D)
# (S_GLOBAL, B, H_LOCAL, D)
# -> (B, S_GLOBAL, H_LOCAL, D)
x = x.permute(1, 0, 2, 3).contiguous()
x = _maybe_unpad_qkv_head(x, H_PAD, group)
return x
return wait
@maybe_allow_in_graph
def all_to_all_single_any_o_async(x: torch.Tensor, group: dist.ProcessGroup, **kwargs) -> Callable[..., torch.Tensor]:
r"""
x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) return: Callable that returns (B, S_LOCAL, H_GLOBAL, D)
"""
# Assume H is provided in kwargs, since we can't infer H from x's shape.
# The padding logic needs H to determine if padding is necessary.
H = kwargs.get("NUM_QO_HEAD", None)
world_size = dist.get_world_size(group=group)
x, H_PAD = _maybe_pad_o_head(x, H, group)
shape = x.shape # (B, S_GLOBAL, H_LOCAL, D)
(B, S_GLOBAL, H_LOCAL, D) = shape
# input_split: e.g, S_GLOBAL=9 input splits across ranks [[5,4], [5,4],..]
# output_split: e.g, S_GLOBAL=9 output splits across ranks [[5,5], [4,4],..]
# WARN: In some cases, e.g, joint attn in Qwen-Image, the S_LOCAL can not infer
# from tensor split due to: if c = torch.cat((a, b)), world_size=4, then,
# c.tensor_split(4)[0].shape[1] may != to (a.tensor_split(4)[0].shape[1] +
# b.tensor_split(4)[0].shape[1])
S_LOCAL = kwargs.get("Q_S_LOCAL")
input_split_sizes = gather_size_by_comm(S_LOCAL, group)
x = x.permute(1, 0, 2, 3).contiguous() # (S_GLOBAL, B, H_LOCAL, D)
output_split_sizes = [S_LOCAL] * world_size
x = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group)
def wait() -> torch.Tensor:
nonlocal x, H_PAD
x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D)
x = x.reshape(world_size, S_LOCAL, B, H_LOCAL, D)
x = x.permute(2, 1, 0, 3, 4).contiguous()
x = x.reshape(B, S_LOCAL, world_size * H_LOCAL, D)
x = _maybe_unpad_o_head(x, H_PAD, group)
return x
return wait
class TemplatedRingAttention(torch.autograd.Function):
@staticmethod
def forward(
@@ -1653,82 +1501,6 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
class TemplatedUlyssesAnythingAttention(torch.autograd.Function):
@staticmethod
def forward(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
dropout_p: float,
is_causal: bool,
scale: Optional[float],
enable_gqa: bool,
return_lse: bool,
forward_op,
backward_op,
_parallel_config: Optional["ParallelConfig"] = None,
**kwargs,
):
ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
group = ulysses_mesh.get_group()
ctx.forward_op = forward_op
ctx.backward_op = backward_op
ctx._parallel_config = _parallel_config
metadata = ulysses_anything_metadata(query)
query_wait = all_to_all_single_any_qkv_async(query, group, **metadata)
key_wait = all_to_all_single_any_qkv_async(key, group, **metadata)
value_wait = all_to_all_single_any_qkv_async(value, group, **metadata)
query = query_wait() # type: torch.Tensor
key = key_wait() # type: torch.Tensor
value = value_wait() # type: torch.Tensor
out = forward_op(
ctx,
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
_save_ctx=False, # ulysses anything only support forward pass now.
_parallel_config=_parallel_config,
)
if return_lse:
out, lse, *_ = out
# out: (B, S_Q_GLOBAL, H_LOCAL, D) -> (B, S_Q_LOCAL, H_GLOBAL, D)
out_wait = all_to_all_single_any_o_async(out, group, **metadata)
if return_lse:
# lse: (B, S_Q_GLOBAL, H_LOCAL)
lse = lse.unsqueeze(-1) # (B, S_Q_GLOBAL, H_LOCAL, D=1)
lse_wait = all_to_all_single_any_o_async(lse, group, **metadata)
out = out_wait() # type: torch.Tensor
lse = lse_wait() # type: torch.Tensor
lse = lse.squeeze(-1).contiguous() # (B, S_Q_LOCAL, H_GLOBAL)
else:
out = out_wait() # type: torch.Tensor
lse = None
return (out, lse) if return_lse else out
@staticmethod
def backward(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
):
raise NotImplementedError("Backward pass for Ulysses Anything Attention in diffusers is not implemented yet.")
def _templated_unified_attention(
query: torch.Tensor,
key: torch.Tensor,
@@ -1846,37 +1618,20 @@ def _templated_context_parallel_attention(
_parallel_config,
)
elif _parallel_config.context_parallel_config.ulysses_degree > 1:
if _parallel_config.context_parallel_config.ulysses_anything:
# For Any sequence lengths and Any head num support
return TemplatedUlyssesAnythingAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
else:
return TemplatedUlyssesAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
return TemplatedUlyssesAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
else:
raise ValueError("Reaching this branch of code is unexpected. Please report a bug.")

View File

@@ -45,7 +45,16 @@ else:
"InsertableDict",
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"]
_import_structure["wan"] = [
"WanBlocks",
"Wan22Blocks",
"WanImage2VideoAutoBlocks",
"Wan22Image2VideoBlocks",
"WanModularPipeline",
"Wan22ModularPipeline",
"WanImage2VideoModularPipeline",
"Wan22Image2VideoModularPipeline",
]
_import_structure["flux"] = [
"FluxAutoBlocks",
"FluxModularPipeline",
@@ -58,6 +67,7 @@ else:
"Flux2KleinBaseAutoBlocks",
"Flux2ModularPipeline",
"Flux2KleinModularPipeline",
"Flux2KleinBaseModularPipeline",
]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
@@ -88,6 +98,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinBaseModularPipeline,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
)
@@ -112,7 +123,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageModularPipeline,
)
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
from .wan import (
Wan22Blocks,
Wan22Image2VideoBlocks,
Wan22Image2VideoModularPipeline,
Wan22ModularPipeline,
WanBlocks,
WanImage2VideoAutoBlocks,
WanImage2VideoModularPipeline,
WanModularPipeline,
)
from .z_image import ZImageAutoBlocks, ZImageModularPipeline
else:
import sys

View File

@@ -55,7 +55,11 @@ else:
"Flux2VaeEncoderSequentialStep",
]
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]
_import_structure["modular_pipeline"] = [
"Flux2ModularPipeline",
"Flux2KleinModularPipeline",
"Flux2KleinBaseModularPipeline",
]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -101,7 +105,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
)
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
from .modular_pipeline import Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline
else:
import sys

View File

@@ -13,8 +13,6 @@
# limitations under the License.
from typing import Any, Dict, Optional
from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
@@ -59,46 +57,35 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
return num_channels_latents
class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
class Flux2KleinModularPipeline(Flux2ModularPipeline):
"""
A ModularPipeline for Flux2-Klein.
A ModularPipeline for Flux2-Klein (distilled model).
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Flux2KleinBaseAutoBlocks"
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
return "Flux2KleinAutoBlocks"
else:
return "Flux2KleinBaseAutoBlocks"
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
return 128
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if getattr(self, "vae", None) is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_latents(self):
num_channels_latents = 32
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
default_blocks_name = "Flux2KleinAutoBlocks"
@property
def requires_unconditional_embeds(self):
if hasattr(self.config, "is_distilled") and self.config.is_distilled:
return False
requires_unconditional_embeds = False
if hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
return requires_unconditional_embeds
class Flux2KleinBaseModularPipeline(Flux2ModularPipeline):
"""
A ModularPipeline for Flux2-Klein (base model).
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Flux2KleinBaseAutoBlocks"
@property
def requires_unconditional_embeds(self):

View File

@@ -52,19 +52,61 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# map regular pipeline to modular pipeline class name
def _create_default_map_fn(pipeline_class_name: str):
"""Create a mapping function that always returns the same pipeline class."""
def _map_fn(config_dict=None):
return pipeline_class_name
return _map_fn
def _flux2_klein_map_fn(config_dict=None):
if config_dict is None:
return "Flux2KleinModularPipeline"
if "is_distilled" in config_dict and config_dict["is_distilled"]:
return "Flux2KleinModularPipeline"
else:
return "Flux2KleinBaseModularPipeline"
def _wan_map_fn(config_dict=None):
if config_dict is None:
return "WanModularPipeline"
if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
return "Wan22ModularPipeline"
else:
return "WanModularPipeline"
def _wan_i2v_map_fn(config_dict=None):
if config_dict is None:
return "WanImage2VideoModularPipeline"
if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
return "Wan22Image2VideoModularPipeline"
else:
return "WanImage2VideoModularPipeline"
MODULAR_PIPELINE_MAPPING = OrderedDict(
[
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
("wan", "WanModularPipeline"),
("flux", "FluxModularPipeline"),
("flux-kontext", "FluxKontextModularPipeline"),
("flux2", "Flux2ModularPipeline"),
("flux2-klein", "Flux2KleinModularPipeline"),
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
("qwenimage-layered", "QwenImageLayeredModularPipeline"),
("z-image", "ZImageModularPipeline"),
("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")),
("wan", _wan_map_fn),
("wan-i2v", _wan_i2v_map_fn),
("flux", _create_default_map_fn("FluxModularPipeline")),
("flux-kontext", _create_default_map_fn("FluxKontextModularPipeline")),
("flux2", _create_default_map_fn("Flux2ModularPipeline")),
("flux2-klein", _flux2_klein_map_fn),
("qwenimage", _create_default_map_fn("QwenImageModularPipeline")),
("qwenimage-edit", _create_default_map_fn("QwenImageEditModularPipeline")),
("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")),
("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")),
("z-image", _create_default_map_fn("ZImageModularPipeline")),
]
)
@@ -366,7 +408,8 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"""
create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub.
"""
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
map_fn = MODULAR_PIPELINE_MAPPING.get(self.model_name, _create_default_map_fn("ModularPipeline"))
pipeline_class_name = map_fn()
diffusers_module = importlib.import_module("diffusers")
pipeline_class = getattr(diffusers_module, pipeline_class_name)
@@ -1545,7 +1588,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
if modular_config_dict is not None:
blocks_class_name = modular_config_dict.get("_blocks_class_name")
else:
blocks_class_name = self.get_default_blocks_name(config_dict)
blocks_class_name = self.default_blocks_name
if blocks_class_name is not None:
diffusers_module = importlib.import_module("diffusers")
blocks_class = getattr(diffusers_module, blocks_class_name)
@@ -1617,9 +1660,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
params[input_param.name] = input_param.default
return params
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
return self.default_blocks_name
@classmethod
def _load_pipeline_config(
cls,
@@ -1715,7 +1755,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
logger.debug(" try to determine the modular pipeline class from model_index.json")
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
model_name = _get_model(standard_pipeline_class.__name__)
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
map_fn = MODULAR_PIPELINE_MAPPING.get(model_name, _create_default_map_fn("ModularPipeline"))
pipeline_class_name = map_fn(config_dict)
diffusers_module = importlib.import_module("diffusers")
pipeline_class = getattr(diffusers_module, pipeline_class_name)
else:

View File

@@ -21,16 +21,16 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["decoders"] = ["WanImageVaeDecoderStep"]
_import_structure["encoders"] = ["WanTextEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"Wan22AutoBlocks",
"WanAutoBlocks",
"WanAutoImageEncoderStep",
"WanAutoVaeImageEncoderStep",
_import_structure["modular_blocks_wan"] = ["WanBlocks"]
_import_structure["modular_blocks_wan22"] = ["Wan22Blocks"]
_import_structure["modular_blocks_wan22_i2v"] = ["Wan22Image2VideoBlocks"]
_import_structure["modular_blocks_wan_i2v"] = ["WanImage2VideoAutoBlocks"]
_import_structure["modular_pipeline"] = [
"Wan22Image2VideoModularPipeline",
"Wan22ModularPipeline",
"WanImage2VideoModularPipeline",
"WanModularPipeline",
]
_import_structure["modular_pipeline"] = ["WanModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -39,16 +39,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .decoders import WanImageVaeDecoderStep
from .encoders import WanTextEncoderStep
from .modular_blocks import (
ALL_BLOCKS,
Wan22AutoBlocks,
WanAutoBlocks,
WanAutoImageEncoderStep,
WanAutoVaeImageEncoderStep,
from .modular_blocks_wan import WanBlocks
from .modular_blocks_wan22 import Wan22Blocks
from .modular_blocks_wan22_i2v import Wan22Image2VideoBlocks
from .modular_blocks_wan_i2v import WanImage2VideoAutoBlocks
from .modular_pipeline import (
Wan22Image2VideoModularPipeline,
Wan22ModularPipeline,
WanImage2VideoModularPipeline,
WanModularPipeline,
)
from .modular_pipeline import WanModularPipeline
else:
import sys

View File

@@ -280,7 +280,7 @@ class WanAdditionalInputsStep(ModularPipelineBlocks):
def __init__(
self,
image_latent_inputs: List[str] = ["first_frame_latents"],
image_latent_inputs: List[str] = ["image_condition_latents"],
additional_batch_inputs: List[str] = [],
):
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
@@ -294,20 +294,16 @@ class WanAdditionalInputsStep(ModularPipelineBlocks):
Args:
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be
a single string or list of strings. Defaults to ["first_frame_latents"].
a single string or list of strings. Defaults to ["image_condition_latents"].
additional_batch_inputs (List[str], optional):
Names of additional conditional input tensors to expand batch size. These tensors will only have their
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
Defaults to [].
Examples:
# Configure to process first_frame_latents (default behavior) WanAdditionalInputsStep()
# Configure to process multiple image latent inputs
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents", "last_frame_latents"])
# Configure to process image latents and additional batch inputs WanAdditionalInputsStep(
image_latent_inputs=["first_frame_latents"], additional_batch_inputs=["image_embeds"]
# Configure to process image_condition_latents (default behavior) WanAdditionalInputsStep() # Configure to
process image latents and additional batch inputs WanAdditionalInputsStep(
image_latent_inputs=["image_condition_latents"], additional_batch_inputs=["image_embeds"]
)
"""
if not isinstance(image_latent_inputs, list):
@@ -557,81 +553,3 @@ class WanPrepareLatentsStep(ModularPipelineBlocks):
self.set_block_state(state, block_state)
return components, state
class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "step that prepares the masked first frame latents and add it to the latent condition"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
InputParam("num_frames", type_hint=int),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)
self.set_block_state(state, block_state)
return components, state
class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "step that prepares the masked latents with first and last frames and add it to the latent condition"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
InputParam("num_frames", type_hint=int),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
block_state.first_last_frame_latents = torch.concat(
[mask_lat_size, block_state.first_last_frame_latents], dim=1
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -29,7 +29,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanImageVaeDecoderStep(ModularPipelineBlocks):
class WanVaeDecoderStep(ModularPipelineBlocks):
model_name = "wan"
@property

View File

@@ -89,52 +89,10 @@ class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks):
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"first_frame_latents",
"image_condition_latents",
required=True,
type_hint=torch.Tensor,
description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.",
),
InputParam(
"dtype",
required=True,
type_hint=torch.dtype,
description="The dtype of the model inputs. Can be generated in input step.",
),
]
@torch.no_grad()
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to(
block_state.dtype
)
return components, block_state
class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return (
"step within the denoising loop that prepares the latent input for the denoiser. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `WanDenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"first_last_frame_latents",
required=True,
type_hint=torch.Tensor,
description="The first and last frame latents to use for the denoising process. Can be generated in prepare_first_last_frame_latents step.",
description="The image condition latents to use for the denoising process. Can be generated in prepare_first_frame_latents/prepare_first_last_frame_latents step.",
),
InputParam(
"dtype",
@@ -147,7 +105,7 @@ class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks):
@torch.no_grad()
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
block_state.latent_model_input = torch.cat(
[block_state.latents, block_state.first_last_frame_latents], dim=1
[block_state.latents, block_state.image_condition_latents], dim=1
).to(block_state.dtype)
return components, block_state
@@ -584,29 +542,3 @@ class Wan22Image2VideoDenoiseStep(WanDenoiseLoopWrapper):
" - `WanLoopAfterDenoiser`\n"
"This block supports image-to-video tasks for Wan2.2."
)
class WanFLF2VDenoiseStep(WanDenoiseLoopWrapper):
block_classes = [
WanFLF2VLoopBeforeDenoiser,
WanLoopDenoiser(
guider_input_fields={
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
"encoder_hidden_states_image": "image_embeds",
}
),
WanLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `WanFLF2VLoopBeforeDenoiser`\n"
" - `WanLoopDenoiser`\n"
" - `WanLoopAfterDenoiser`\n"
"This block supports FLF2V tasks for wan2.1."
)

View File

@@ -468,7 +468,7 @@ class WanFirstLastFrameImageEncoderStep(ModularPipelineBlocks):
return components, state
class WanVaeImageEncoderStep(ModularPipelineBlocks):
class WanVaeEncoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -493,7 +493,7 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks):
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
InputParam("height"),
InputParam("width"),
InputParam("num_frames"),
InputParam("num_frames", type_hint=int, default=81),
InputParam("generator"),
]
@@ -564,7 +564,51 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks):
return components, state
class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "step that prepares the masked first frame latents and add it to the latent condition"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
InputParam("num_frames", required=True),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("image_condition_latents", type_hint=Optional[torch.Tensor]),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
block_state.image_condition_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)
self.set_block_state(state, block_state)
return components, state
class WanFirstLastFrameVaeEncoderStep(ModularPipelineBlocks):
model_name = "wan"
@property
@@ -590,7 +634,7 @@ class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True),
InputParam("height"),
InputParam("width"),
InputParam("num_frames"),
InputParam("num_frames", type_hint=int, default=81),
InputParam("generator"),
]
@@ -667,3 +711,49 @@ class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
self.set_block_state(state, block_state)
return components, state
class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "step that prepares the masked latents with first and last frames and add it to the latent condition"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
InputParam("num_frames", type_hint=int, required=True),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("image_condition_latents", type_hint=Optional[torch.Tensor]),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
block_state.image_condition_latents = torch.concat(
[mask_lat_size, block_state.first_last_frame_latents], dim=1
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -1,474 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
WanAdditionalInputsStep,
WanPrepareFirstFrameLatentsStep,
WanPrepareFirstLastFrameLatentsStep,
WanPrepareLatentsStep,
WanSetTimestepsStep,
WanTextInputStep,
)
from .decoders import WanImageVaeDecoderStep
from .denoise import (
Wan22DenoiseStep,
Wan22Image2VideoDenoiseStep,
WanDenoiseStep,
WanFLF2VDenoiseStep,
WanImage2VideoDenoiseStep,
)
from .encoders import (
WanFirstLastFrameImageEncoderStep,
WanFirstLastFrameVaeImageEncoderStep,
WanImageCropResizeStep,
WanImageEncoderStep,
WanImageResizeStep,
WanTextEncoderStep,
WanVaeImageEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# wan2.1
# wan2.1: text2vid
class WanCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanTextInputStep,
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanDenoiseStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
@property
def description(self):
return (
"denoise block that takes encoded conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanDenoiseStep` is used to denoise the latents\n"
)
# wan2.1: image2video
## image encoder
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanImageEncoderStep]
block_names = ["image_resize", "image_encoder"]
@property
def description(self):
return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"
## vae encoder
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
block_names = ["image_resize", "vae_encoder"]
@property
def description(self):
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
## denoise
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanTextInputStep,
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanPrepareFirstFrameLatentsStep,
WanImage2VideoDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"set_timesteps",
"prepare_latents",
"prepare_first_frame_latents",
"denoise",
]
@property
def description(self):
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
)
# wan2.1: FLF2v
## image encoder
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
block_names = ["image_resize", "last_image_resize", "image_encoder"]
@property
def description(self):
return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings"
## vae encoder
class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
block_names = ["image_resize", "last_image_resize", "vae_encoder"]
@property
def description(self):
return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions"
## denoise
class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanTextInputStep,
WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"]),
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanPrepareFirstLastFrameLatentsStep,
WanFLF2VDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"set_timesteps",
"prepare_latents",
"prepare_first_last_frame_latents",
"denoise",
]
@property
def description(self):
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanPrepareFirstLastFrameLatentsStep` is used to prepare the latent conditions\n"
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
)
# wan2.1: auto blocks
## image encoder
class WanAutoImageEncoderStep(AutoPipelineBlocks):
block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
block_trigger_inputs = ["last_image", "image"]
@property
def description(self):
return (
"Image Encoder step that encode the image to generate the image embeddings"
+ "This is an auto pipeline block that works for image2video tasks."
+ " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided."
+ " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided."
+ " - if `last_image` or `image` is not provided, step will be skipped."
)
## vae encoder
class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
block_trigger_inputs = ["last_image", "image"]
@property
def description(self):
return (
"Vae Image Encoder step that encode the image to generate the image latents"
+ "This is an auto pipeline block that works for image2video tasks."
+ " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided."
+ " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided."
+ " - if `last_image` or `image` is not provided, step will be skipped."
)
## denoise
class WanAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [
WanFLF2VCoreDenoiseStep,
WanImage2VideoCoreDenoiseStep,
WanCoreDenoiseStep,
]
block_names = ["flf2v", "image2video", "text2video"]
block_trigger_inputs = ["first_last_frame_latents", "first_frame_latents", None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2video and image2video tasks."
" - `WanCoreDenoiseStep` (text2video) for text2vid tasks."
" - `WanCoreImage2VideoCoreDenoiseStep` (image2video) for image2video tasks."
+ " - if `first_frame_latents` is provided, `WanCoreImage2VideoDenoiseStep` will be used.\n"
+ " - if `first_frame_latents` is not provided, `WanCoreDenoiseStep` will be used.\n"
)
# auto pipeline blocks
class WanAutoBlocks(SequentialPipelineBlocks):
block_classes = [
WanTextEncoderStep,
WanAutoImageEncoderStep,
WanAutoVaeImageEncoderStep,
WanAutoDenoiseStep,
WanImageVaeDecoderStep,
]
block_names = [
"text_encoder",
"image_encoder",
"vae_encoder",
"denoise",
"decode",
]
@property
def description(self):
return (
"Auto Modular pipeline for text-to-video using Wan.\n"
+ "- for text-to-video generation, all you need to provide is `prompt`"
)
# wan22
# wan2.2: text2vid
## denoise
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanTextInputStep,
WanSetTimestepsStep,
WanPrepareLatentsStep,
Wan22DenoiseStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
@property
def description(self):
return (
"denoise block that takes encoded conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
)
# wan2.2: image2video
## denoise
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
WanTextInputStep,
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanPrepareFirstFrameLatentsStep,
Wan22Image2VideoDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"set_timesteps",
"prepare_latents",
"prepare_first_frame_latents",
"denoise",
]
@property
def description(self):
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
+ " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
)
class Wan22AutoDenoiseStep(AutoPipelineBlocks):
block_classes = [
Wan22Image2VideoCoreDenoiseStep,
Wan22CoreDenoiseStep,
]
block_names = ["image2video", "text2video"]
block_trigger_inputs = ["first_frame_latents", None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2video and image2video tasks."
" - `Wan22Image2VideoCoreDenoiseStep` (image2video) for image2video tasks."
" - `Wan22CoreDenoiseStep` (text2video) for text2vid tasks."
+ " - if `first_frame_latents` is provided, `Wan22Image2VideoCoreDenoiseStep` will be used.\n"
+ " - if `first_frame_latents` is not provided, `Wan22CoreDenoiseStep` will be used.\n"
)
class Wan22AutoBlocks(SequentialPipelineBlocks):
block_classes = [
WanTextEncoderStep,
WanAutoVaeImageEncoderStep,
Wan22AutoDenoiseStep,
WanImageVaeDecoderStep,
]
block_names = [
"text_encoder",
"vae_encoder",
"denoise",
"decode",
]
@property
def description(self):
return (
"Auto Modular pipeline for text-to-video using Wan2.2.\n"
+ "- for text-to-video generation, all you need to provide is `prompt`"
)
# presets for wan2.1 and wan2.2
# YiYi Notes: should we move these to doc?
# wan2.1
TEXT2VIDEO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("input", WanTextInputStep),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("denoise", WanDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
IMAGE2VIDEO_BLOCKS = InsertableDict(
[
("image_resize", WanImageResizeStep),
("image_encoder", WanImage2VideoImageEncoderStep),
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
("input", WanTextInputStep),
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep),
("denoise", WanImage2VideoDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
FLF2V_BLOCKS = InsertableDict(
[
("image_resize", WanImageResizeStep),
("last_image_resize", WanImageCropResizeStep),
("image_encoder", WanFLF2VImageEncoderStep),
("vae_encoder", WanFLF2VVaeImageEncoderStep),
("input", WanTextInputStep),
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("prepare_first_last_frame_latents", WanPrepareFirstLastFrameLatentsStep),
("denoise", WanFLF2VDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("image_encoder", WanAutoImageEncoderStep),
("vae_encoder", WanAutoVaeImageEncoderStep),
("denoise", WanAutoDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
# wan2.2 presets
TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("input", WanTextInputStep),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("denoise", Wan22DenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
[
("image_resize", WanImageResizeStep),
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
("input", WanTextInputStep),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
("denoise", Wan22DenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
AUTO_BLOCKS_WAN22 = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("vae_encoder", WanAutoVaeImageEncoderStep),
("denoise", Wan22AutoDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
)
# presets all blocks (wan and wan22)
ALL_BLOCKS = {
"wan2.1": {
"text2video": TEXT2VIDEO_BLOCKS,
"image2video": IMAGE2VIDEO_BLOCKS,
"flf2v": FLF2V_BLOCKS,
"auto": AUTO_BLOCKS,
},
"wan2.2": {
"text2video": TEXT2VIDEO_BLOCKS_WAN22,
"image2video": IMAGE2VIDEO_BLOCKS_WAN22,
"auto": AUTO_BLOCKS_WAN22,
},
}

View File

@@ -0,0 +1,83 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .before_denoise import (
WanPrepareLatentsStep,
WanSetTimestepsStep,
WanTextInputStep,
)
from .decoders import WanVaeDecoderStep
from .denoise import (
WanDenoiseStep,
)
from .encoders import (
WanTextEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# ====================
# 1. DENOISE
# ====================
# inputs(text) -> set_timesteps -> prepare_latents -> denoise
class WanCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [
WanTextInputStep,
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanDenoiseStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
@property
def description(self):
return (
"denoise block that takes encoded conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanDenoiseStep` is used to denoise the latents\n"
)
# ====================
# 2. BLOCKS (Wan2.1 text2video)
# ====================
class WanBlocks(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [
WanTextEncoderStep,
WanCoreDenoiseStep,
WanVaeDecoderStep,
]
block_names = ["text_encoder", "denoise", "decode"]
@property
def description(self):
return (
"Modular pipeline blocks for Wan2.1.\n"
+ "- `WanTextEncoderStep` is used to encode the text\n"
+ "- `WanCoreDenoiseStep` is used to denoise the latents\n"
+ "- `WanVaeDecoderStep` is used to decode the latents to images"
)

View File

@@ -0,0 +1,88 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .before_denoise import (
WanPrepareLatentsStep,
WanSetTimestepsStep,
WanTextInputStep,
)
from .decoders import WanVaeDecoderStep
from .denoise import (
Wan22DenoiseStep,
)
from .encoders import (
WanTextEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# ====================
# 1. DENOISE
# ====================
# inputs(text) -> set_timesteps -> prepare_latents -> denoise
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [
WanTextInputStep,
WanSetTimestepsStep,
WanPrepareLatentsStep,
Wan22DenoiseStep,
]
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
@property
def description(self):
return (
"denoise block that takes encoded conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
)
# ====================
# 2. BLOCKS (Wan2.2 text2video)
# ====================
class Wan22Blocks(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [
WanTextEncoderStep,
Wan22CoreDenoiseStep,
WanVaeDecoderStep,
]
block_names = [
"text_encoder",
"denoise",
"decode",
]
@property
def description(self):
return (
"Modular pipeline for text-to-video using Wan2.2.\n"
+ " - `WanTextEncoderStep` encodes the text\n"
+ " - `Wan22CoreDenoiseStep` denoes the latents\n"
+ " - `WanVaeDecoderStep` decodes the latents to video frames\n"
)

View File

@@ -0,0 +1,117 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .before_denoise import (
WanAdditionalInputsStep,
WanPrepareLatentsStep,
WanSetTimestepsStep,
WanTextInputStep,
)
from .decoders import WanVaeDecoderStep
from .denoise import (
Wan22Image2VideoDenoiseStep,
)
from .encoders import (
WanImageResizeStep,
WanPrepareFirstFrameLatentsStep,
WanTextEncoderStep,
WanVaeEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# ====================
# 1. VAE ENCODER
# ====================
class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep]
block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"]
@property
def description(self):
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
# ====================
# 2. DENOISE
# ====================
# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents)
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [
WanTextInputStep,
WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]),
WanSetTimestepsStep,
WanPrepareLatentsStep,
Wan22Image2VideoDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"set_timesteps",
"prepare_latents",
"denoise",
]
@property
def description(self):
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
)
# ====================
# 3. BLOCKS (Wan2.2 Image2Video)
# ====================
class Wan22Image2VideoBlocks(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [
WanTextEncoderStep,
WanImage2VideoVaeEncoderStep,
Wan22Image2VideoCoreDenoiseStep,
WanVaeDecoderStep,
]
block_names = [
"text_encoder",
"vae_encoder",
"denoise",
"decode",
]
@property
def description(self):
return (
"Modular pipeline for image-to-video using Wan2.2.\n"
+ " - `WanTextEncoderStep` encodes the text\n"
+ " - `WanImage2VideoVaeEncoderStep` encodes the image\n"
+ " - `Wan22Image2VideoCoreDenoiseStep` denoes the latents\n"
+ " - `WanVaeDecoderStep` decodes the latents to video frames\n"
)

View File

@@ -0,0 +1,203 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from .before_denoise import (
WanAdditionalInputsStep,
WanPrepareLatentsStep,
WanSetTimestepsStep,
WanTextInputStep,
)
from .decoders import WanVaeDecoderStep
from .denoise import (
WanImage2VideoDenoiseStep,
)
from .encoders import (
WanFirstLastFrameImageEncoderStep,
WanFirstLastFrameVaeEncoderStep,
WanImageCropResizeStep,
WanImageEncoderStep,
WanImageResizeStep,
WanPrepareFirstFrameLatentsStep,
WanPrepareFirstLastFrameLatentsStep,
WanTextEncoderStep,
WanVaeEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# ====================
# 1. IMAGE ENCODER
# ====================
# wan2.1 I2V (first frame only)
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [WanImageResizeStep, WanImageEncoderStep]
block_names = ["image_resize", "image_encoder"]
@property
def description(self):
return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"
# wan2.1 FLF2V (first and last frame)
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
block_names = ["image_resize", "last_image_resize", "image_encoder"]
@property
def description(self):
return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings"
# wan2.1 Auto Image Encoder
class WanAutoImageEncoderStep(AutoPipelineBlocks):
block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
block_trigger_inputs = ["last_image", "image"]
model_name = "wan-i2v"
@property
def description(self):
return (
"Image Encoder step that encode the image to generate the image embeddings"
+ "This is an auto pipeline block that works for image2video tasks."
+ " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided."
+ " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided."
+ " - if `last_image` or `image` is not provided, step will be skipped."
)
# ====================
# 2. VAE ENCODER
# ====================
# wan2.1 I2V (first frame only)
class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep]
block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"]
@property
def description(self):
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
# wan2.1 FLF2V (first and last frame)
class WanFLF2VVaeEncoderStep(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [
WanImageResizeStep,
WanImageCropResizeStep,
WanFirstLastFrameVaeEncoderStep,
WanPrepareFirstLastFrameLatentsStep,
]
block_names = ["image_resize", "last_image_resize", "vae_encoder", "prepare_first_last_frame_latents"]
@property
def description(self):
return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions"
# wan2.1 Auto Vae Encoder
class WanAutoVaeEncoderStep(AutoPipelineBlocks):
model_name = "wan-i2v"
block_classes = [WanFLF2VVaeEncoderStep, WanImage2VideoVaeEncoderStep]
block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
block_trigger_inputs = ["last_image", "image"]
@property
def description(self):
return (
"Vae Image Encoder step that encode the image to generate the image latents"
+ "This is an auto pipeline block that works for image2video tasks."
+ " - `WanFLF2VVaeEncoderStep` (flf2v) is used when `last_image` is provided."
+ " - `WanImage2VideoVaeEncoderStep` (image2video) is used when `image` is provided."
+ " - if `last_image` or `image` is not provided, step will be skipped."
)
# ====================
# 3. DENOISE (inputs -> set_timesteps -> prepare_latents -> denoise)
# ====================
# wan2.1 I2V core denoise (support both I2V and FLF2V)
# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents)
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [
WanTextInputStep,
WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]),
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanImage2VideoDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"set_timesteps",
"prepare_latents",
"denoise",
]
@property
def description(self):
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
)
# ====================
# 4. BLOCKS (Wan2.1 Image2Video)
# ====================
# wan2.1 Image2Video Auto Blocks
class WanImage2VideoAutoBlocks(SequentialPipelineBlocks):
model_name = "wan-i2v"
block_classes = [
WanTextEncoderStep,
WanAutoImageEncoderStep,
WanAutoVaeEncoderStep,
WanImage2VideoCoreDenoiseStep,
WanVaeDecoderStep,
]
block_names = [
"text_encoder",
"image_encoder",
"vae_encoder",
"denoise",
"decode",
]
@property
def description(self):
return (
"Auto Modular pipeline for image-to-video using Wan.\n"
+ "- for I2V workflow, all you need to provide is `image`"
+ "- for FLF2V workflow, all you need to provide is `last_image` and `image`"
)

View File

@@ -13,8 +13,6 @@
# limitations under the License.
from typing import Any, Dict, Optional
from ...loaders import WanLoraLoaderMixin
from ...pipelines.pipeline_utils import StableDiffusionMixin
from ...utils import logging
@@ -30,19 +28,12 @@ class WanModularPipeline(
WanLoraLoaderMixin,
):
"""
A ModularPipeline for Wan.
A ModularPipeline for Wan2.1 text2video.
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "WanAutoBlocks"
# override the default_blocks_name in base class, which is just return self.default_blocks_name
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
if config_dict is not None and "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
return "Wan22AutoBlocks"
else:
return "WanAutoBlocks"
default_blocks_name = "WanBlocks"
@property
def default_height(self):
@@ -118,3 +109,33 @@ class WanModularPipeline(
if hasattr(self, "scheduler") and self.scheduler is not None:
num_train_timesteps = self.scheduler.config.num_train_timesteps
return num_train_timesteps
class WanImage2VideoModularPipeline(WanModularPipeline):
"""
A ModularPipeline for Wan2.1 image2video (both I2V and FLF2V).
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "WanImage2VideoAutoBlocks"
class Wan22ModularPipeline(WanModularPipeline):
"""
A ModularPipeline for Wan2.2 text2video.
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Wan22Blocks"
class Wan22Image2VideoModularPipeline(Wan22ModularPipeline):
"""
A ModularPipeline for Wan2.2 image2video.
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Wan22Image2VideoBlocks"

View File

@@ -246,7 +246,7 @@ AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
AUTO_IMAGE2VIDEO_PIPELINES_MAPPING = OrderedDict(
[
("wan", WanImageToVideoPipeline),
("wan-i2v", WanImageToVideoPipeline),
]
)

View File

@@ -47,6 +47,21 @@ class Flux2KleinBaseAutoBlocks(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinBaseModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2KleinModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -287,7 +302,7 @@ class StableDiffusionXLModularPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Wan22AutoBlocks(metaclass=DummyObject):
class Wan22Blocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
@@ -302,7 +317,82 @@ class Wan22AutoBlocks(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class WanAutoBlocks(metaclass=DummyObject):
class Wan22Image2VideoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Wan22Image2VideoModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Wan22ModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WanBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WanImage2VideoAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WanImage2VideoModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):

View File

@@ -32,22 +32,6 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
def pytest_configure(config):
config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
config.addinivalue_line("markers", "training: marks tests for training functionality")
config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
config.addinivalue_line("markers", "quantization: marks tests for quantization functionality")
config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality")
config.addinivalue_line("markers", "slow: mark test as slow")
config.addinivalue_line("markers", "nightly: mark test as nightly")

View File

@@ -1,79 +0,0 @@
from .attention import AttentionTesterMixin
from .cache import (
CacheTesterMixin,
FasterCacheConfigMixin,
FasterCacheTesterMixin,
FirstBlockCacheConfigMixin,
FirstBlockCacheTesterMixin,
PyramidAttentionBroadcastConfigMixin,
PyramidAttentionBroadcastTesterMixin,
)
from .common import BaseModelTesterConfig, ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .parallelism import ContextParallelTesterMixin
from .quantization import (
BitsAndBytesCompileTesterMixin,
BitsAndBytesConfigMixin,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFConfigMixin,
GGUFTesterMixin,
ModelOptCompileTesterMixin,
ModelOptConfigMixin,
ModelOptTesterMixin,
QuantizationCompileTesterMixin,
QuantizationTesterMixin,
QuantoCompileTesterMixin,
QuantoConfigMixin,
QuantoTesterMixin,
TorchAoCompileTesterMixin,
TorchAoConfigMixin,
TorchAoTesterMixin,
)
from .single_file import SingleFileTesterMixin
from .training import TrainingTesterMixin
__all__ = [
"AttentionTesterMixin",
"BaseModelTesterConfig",
"BitsAndBytesCompileTesterMixin",
"BitsAndBytesConfigMixin",
"BitsAndBytesTesterMixin",
"CacheTesterMixin",
"ContextParallelTesterMixin",
"CPUOffloadTesterMixin",
"FasterCacheConfigMixin",
"FasterCacheTesterMixin",
"FirstBlockCacheConfigMixin",
"FirstBlockCacheTesterMixin",
"GGUFCompileTesterMixin",
"GGUFConfigMixin",
"GGUFTesterMixin",
"GroupOffloadTesterMixin",
"IPAdapterTesterMixin",
"LayerwiseCastingTesterMixin",
"LoraHotSwappingForModelTesterMixin",
"LoraTesterMixin",
"MemoryTesterMixin",
"ModelOptCompileTesterMixin",
"ModelOptConfigMixin",
"ModelOptTesterMixin",
"ModelTesterMixin",
"PyramidAttentionBroadcastConfigMixin",
"PyramidAttentionBroadcastTesterMixin",
"QuantizationCompileTesterMixin",
"QuantizationTesterMixin",
"QuantoCompileTesterMixin",
"QuantoConfigMixin",
"QuantoTesterMixin",
"SingleFileTesterMixin",
"TorchAoCompileTesterMixin",
"TorchAoConfigMixin",
"TorchAoTesterMixin",
"TorchCompileTesterMixin",
"TrainingTesterMixin",
]

View File

@@ -1,181 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import pytest
import torch
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
AttnProcessor,
)
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
is_attention,
torch_device,
)
@is_attention
class AttentionTesterMixin:
"""
Mixin class for testing attention processor and module functionality on models.
Tests functionality from AttentionModuleMixin including:
- Attention processor management (set/get)
- QKV projection fusion/unfusion
- Attention backends (XFormers, NPU, etc.)
Expected from config mixin:
- model_class: The model class to test
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: attention
Use `pytest -m "not attention"` to skip these tests
"""
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@torch.no_grad()
def test_fuse_unfuse_qkv_projections(self, atol=1e-3, rtol=0):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
if not hasattr(model, "fuse_qkv_projections"):
pytest.skip("Model does not support QKV projection fusion.")
output_before_fusion = model(**inputs_dict, return_dict=False)[0]
model.fuse_qkv_projections()
has_fused_projections = False
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
has_fused_projections = True
assert module.fused_projections, "fused_projections flag should be True"
break
if has_fused_projections:
output_after_fusion = model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
output_before_fusion,
output_after_fusion,
atol=atol,
rtol=rtol,
msg="Output should not change after fusing projections",
)
model.unfuse_qkv_projections()
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
assert not module.fused_projections, "fused_projections flag should be False"
output_after_unfusion = model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
output_before_fusion,
output_after_unfusion,
atol=atol,
rtol=rtol,
msg="Output should match original after unfusing projections",
)
def test_get_set_processor(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
# Check if model has attention processors
if not hasattr(model, "attn_processors"):
pytest.skip("Model does not have attention processors.")
# Test getting processors
processors = model.attn_processors
assert isinstance(processors, dict), "attn_processors should return a dict"
assert len(processors) > 0, "Model should have at least one attention processor"
# Test that all processors can be retrieved via get_processor
for module in model.modules():
if isinstance(module, AttentionModuleMixin):
processor = module.get_processor()
assert processor is not None, "get_processor should return a processor"
# Test setting a new processor
new_processor = AttnProcessor()
module.set_processor(new_processor)
retrieved_processor = module.get_processor()
assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"
def test_attention_processor_dict(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
pytest.skip("Model does not support setting attention processors.")
# Get current processors
current_processors = model.attn_processors
# Create a dict of new processors
new_processors = {key: AttnProcessor() for key in current_processors.keys()}
# Set processors using dict
model.set_attn_processor(new_processors)
# Verify all processors were set
updated_processors = model.attn_processors
for key in current_processors.keys():
assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"
def test_attention_processor_count_mismatch_raises_error(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
if not hasattr(model, "set_attn_processor"):
pytest.skip("Model does not support setting attention processors.")
# Get current processors
current_processors = model.attn_processors
# Create a dict with wrong number of processors
wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}
# Verify error is raised
with pytest.raises(ValueError) as exc_info:
model.set_attn_processor(wrong_processors)
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"

View File

@@ -1,556 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import pytest
import torch
from diffusers.hooks import FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from diffusers.models.cache_utils import CacheMixin
from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device
def require_cache_mixin(func):
"""Decorator to skip tests if model doesn't use CacheMixin."""
def wrapper(self, *args, **kwargs):
if not issubclass(self.model_class, CacheMixin):
pytest.skip(f"{self.model_class.__name__} does not use CacheMixin.")
return func(self, *args, **kwargs)
return wrapper
class CacheTesterMixin:
"""
Base mixin class providing common test implementations for cache testing.
Cache-specific mixins should:
1. Inherit from their respective config mixin (e.g., PyramidAttentionBroadcastConfigMixin)
2. Inherit from this mixin
3. Define the cache config to use for tests
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods in test classes:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Optional overrides:
- cache_input_key: Property returning the input tensor key to vary between passes (default: "hidden_states")
"""
@property
def cache_input_key(self):
return "hidden_states"
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def _get_cache_config(self):
"""
Get the cache config for testing.
Should be implemented by subclasses.
"""
raise NotImplementedError("Subclass must implement _get_cache_config")
def _get_hook_names(self):
"""
Get the hook names to check for this cache type.
Should be implemented by subclasses.
Returns a list of hook name strings.
"""
raise NotImplementedError("Subclass must implement _get_hook_names")
def _test_cache_enable_disable_state(self):
"""Test that cache enable/disable updates the is_cache_enabled state correctly."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
# Initially cache should not be enabled
assert not model.is_cache_enabled, "Cache should not be enabled initially."
config = self._get_cache_config()
# Enable cache
model.enable_cache(config)
assert model.is_cache_enabled, "Cache should be enabled after enable_cache()."
# Disable cache
model.disable_cache()
assert not model.is_cache_enabled, "Cache should not be enabled after disable_cache()."
def _test_cache_double_enable_raises_error(self):
"""Test that enabling cache twice raises an error."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
config = self._get_cache_config()
model.enable_cache(config)
# Trying to enable again should raise ValueError
with pytest.raises(ValueError, match="Caching has already been enabled"):
model.enable_cache(config)
# Cleanup
model.disable_cache()
def _test_cache_hooks_registered(self):
"""Test that cache hooks are properly registered and removed."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
config = self._get_cache_config()
hook_names = self._get_hook_names()
model.enable_cache(config)
# Check that at least one hook was registered
hook_count = 0
for module in model.modules():
if hasattr(module, "_diffusers_hook"):
for hook_name in hook_names:
hook = module._diffusers_hook.get_hook(hook_name)
if hook is not None:
hook_count += 1
assert hook_count > 0, f"At least one cache hook should be registered. Hook names: {hook_names}"
# Disable and verify hooks are removed
model.disable_cache()
hook_count_after = 0
for module in model.modules():
if hasattr(module, "_diffusers_hook"):
for hook_name in hook_names:
hook = module._diffusers_hook.get_hook(hook_name)
if hook is not None:
hook_count_after += 1
assert hook_count_after == 0, "Cache hooks should be removed after disable_cache()."
@torch.no_grad()
def _test_cache_inference(self):
"""Test that model can run inference with cache enabled."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
# First pass populates the cache
_ = model(**inputs_dict, return_dict=False)[0]
# Create modified inputs for second pass (vary input tensor to simulate denoising)
inputs_dict_step2 = inputs_dict.copy()
if self.cache_input_key in inputs_dict_step2:
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
inputs_dict_step2[self.cache_input_key]
)
# Second pass uses cached attention with different inputs (produces approximated output)
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
assert output_with_cache is not None, "Model output should not be None with cache enabled."
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
# Run same inputs without cache to compare
model.disable_cache()
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
# Cached output should be different from non-cached output (due to approximation)
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
"Cached output should be different from non-cached output due to cache approximation."
)
@torch.no_grad()
def _test_cache_context_manager(self, atol=1e-5, rtol=0):
"""Test the cache_context context manager properly isolates cache state."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
# Run inference in first context
with model.cache_context("context_1"):
output_ctx1 = model(**inputs_dict, return_dict=False)[0]
# Run same inference in second context (cache should be reset)
with model.cache_context("context_2"):
output_ctx2 = model(**inputs_dict, return_dict=False)[0]
# Both contexts should produce the same output (first pass in each)
assert_tensors_close(
output_ctx1,
output_ctx2,
atol=atol,
rtol=rtol,
msg="First pass in different cache contexts should produce the same output.",
)
model.disable_cache()
@torch.no_grad()
def _test_reset_stateful_cache(self):
"""Test that _reset_stateful_cache resets the cache state."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
_ = model(**inputs_dict, return_dict=False)[0]
model._reset_stateful_cache()
model.disable_cache()
@is_cache
class PyramidAttentionBroadcastConfigMixin:
"""
Base mixin providing PyramidAttentionBroadcast cache config.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
"""
# Default PAB config - can be overridden by subclasses
PAB_CONFIG = {
"spatial_attention_block_skip_range": 2,
}
# Store timestep for callback (must be within default range (100, 800) for skipping to trigger)
_current_timestep = 500
def _get_cache_config(self):
config_kwargs = self.PAB_CONFIG.copy()
config_kwargs["current_timestep_callback"] = lambda: self._current_timestep
return PyramidAttentionBroadcastConfig(**config_kwargs)
def _get_hook_names(self):
return [_PYRAMID_ATTENTION_BROADCAST_HOOK]
@is_cache
class PyramidAttentionBroadcastTesterMixin(PyramidAttentionBroadcastConfigMixin, CacheTesterMixin):
"""
Mixin class for testing PyramidAttentionBroadcast caching on models.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cache
Use `pytest -m "not cache"` to skip these tests
"""
@require_cache_mixin
def test_pab_cache_enable_disable_state(self):
self._test_cache_enable_disable_state()
@require_cache_mixin
def test_pab_cache_double_enable_raises_error(self):
self._test_cache_double_enable_raises_error()
@require_cache_mixin
def test_pab_cache_hooks_registered(self):
self._test_cache_hooks_registered()
@require_cache_mixin
def test_pab_cache_inference(self):
self._test_cache_inference()
@require_cache_mixin
def test_pab_cache_context_manager(self):
self._test_cache_context_manager()
@require_cache_mixin
def test_pab_reset_stateful_cache(self):
self._test_reset_stateful_cache()
@is_cache
class FirstBlockCacheConfigMixin:
"""
Base mixin providing FirstBlockCache config.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
"""
# Default FBC config - can be overridden by subclasses
# Higher threshold makes FBC more aggressive about caching (skips more often)
FBC_CONFIG = {
"threshold": 1.0,
}
def _get_cache_config(self):
return FirstBlockCacheConfig(**self.FBC_CONFIG)
def _get_hook_names(self):
return [_FBC_LEADER_BLOCK_HOOK, _FBC_BLOCK_HOOK]
@is_cache
class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
"""
Mixin class for testing FirstBlockCache on models.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cache
Use `pytest -m "not cache"` to skip these tests
"""
@torch.no_grad()
def _test_cache_inference(self):
"""Test that model can run inference with FBC cache enabled (requires cache_context)."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
# FBC requires cache_context to be set for inference
with model.cache_context("fbc_test"):
# First pass populates the cache
_ = model(**inputs_dict, return_dict=False)[0]
# Create modified inputs for second pass
inputs_dict_step2 = inputs_dict.copy()
if self.cache_input_key in inputs_dict_step2:
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
inputs_dict_step2[self.cache_input_key]
)
# Second pass - FBC should skip remaining blocks and use cached residuals
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
assert output_with_cache is not None, "Model output should not be None with cache enabled."
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
# Run same inputs without cache to compare
model.disable_cache()
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
# Cached output should be different from non-cached output (due to approximation)
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
"Cached output should be different from non-cached output due to cache approximation."
)
@torch.no_grad()
def _test_reset_stateful_cache(self):
"""Test that _reset_stateful_cache resets the FBC cache state (requires cache_context)."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
with model.cache_context("fbc_test"):
_ = model(**inputs_dict, return_dict=False)[0]
model._reset_stateful_cache()
model.disable_cache()
@require_cache_mixin
def test_fbc_cache_enable_disable_state(self):
self._test_cache_enable_disable_state()
@require_cache_mixin
def test_fbc_cache_double_enable_raises_error(self):
self._test_cache_double_enable_raises_error()
@require_cache_mixin
def test_fbc_cache_hooks_registered(self):
self._test_cache_hooks_registered()
@require_cache_mixin
def test_fbc_cache_inference(self):
self._test_cache_inference()
@require_cache_mixin
def test_fbc_cache_context_manager(self):
self._test_cache_context_manager()
@require_cache_mixin
def test_fbc_reset_stateful_cache(self):
self._test_reset_stateful_cache()
@is_cache
class FasterCacheConfigMixin:
"""
Base mixin providing FasterCache config.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
"""
# Default FasterCache config - can be overridden by subclasses
FASTER_CACHE_CONFIG = {
"spatial_attention_block_skip_range": 2,
"spatial_attention_timestep_skip_range": (-1, 901),
"tensor_format": "BCHW",
}
def _get_cache_config(self, current_timestep_callback=None):
config_kwargs = self.FASTER_CACHE_CONFIG.copy()
if current_timestep_callback is None:
current_timestep_callback = lambda: 1000 # noqa: E731
config_kwargs["current_timestep_callback"] = current_timestep_callback
return FasterCacheConfig(**config_kwargs)
def _get_hook_names(self):
return [_FASTER_CACHE_DENOISER_HOOK, _FASTER_CACHE_BLOCK_HOOK]
@is_cache
class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
"""
Mixin class for testing FasterCache on models.
Note: FasterCache is designed for pipeline-level inference with proper CFG batch handling
and timestep management. Inference tests are skipped at model level - FasterCache should
be tested via pipeline tests (e.g., FluxPipeline, HunyuanVideoPipeline).
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cache
Use `pytest -m "not cache"` to skip these tests
"""
@torch.no_grad()
def _test_cache_inference(self):
"""Test that model can run inference with FasterCache enabled."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
current_timestep = [1000]
config = self._get_cache_config(current_timestep_callback=lambda: current_timestep[0])
model.enable_cache(config)
# First pass with timestep outside skip range - computes and populates cache
current_timestep[0] = 1000
_ = model(**inputs_dict, return_dict=False)[0]
# Move timestep inside skip range so subsequent passes use cache
current_timestep[0] = 500
# Create modified inputs for second pass
inputs_dict_step2 = inputs_dict.copy()
if self.cache_input_key in inputs_dict_step2:
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
inputs_dict_step2[self.cache_input_key]
)
# Second pass uses cached attention with different inputs
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
assert output_with_cache is not None, "Model output should not be None with cache enabled."
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
# Run same inputs without cache to compare
model.disable_cache()
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
# Cached output should be different from non-cached output (due to approximation)
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
"Cached output should be different from non-cached output due to cache approximation."
)
@torch.no_grad()
def _test_reset_stateful_cache(self):
"""Test that _reset_stateful_cache resets the FasterCache state."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
_ = model(**inputs_dict, return_dict=False)[0]
model._reset_stateful_cache()
model.disable_cache()
@require_cache_mixin
def test_faster_cache_enable_disable_state(self):
self._test_cache_enable_disable_state()
@require_cache_mixin
def test_faster_cache_double_enable_raises_error(self):
self._test_cache_double_enable_raises_error()
@require_cache_mixin
def test_faster_cache_hooks_registered(self):
self._test_cache_hooks_registered()
@require_cache_mixin
def test_faster_cache_inference(self):
self._test_cache_inference()
@require_cache_mixin
def test_faster_cache_context_manager(self):
self._test_cache_context_manager()
@require_cache_mixin
def test_faster_cache_reset_stateful_cache(self):
self._test_reset_stateful_cache()

View File

@@ -1,666 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections import defaultdict
from typing import Any, Dict, Optional, Type
import pytest
import torch
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator
from ...testing_utils import assert_tensors_close, torch_device
def named_persistent_module_tensors(
module: nn.Module,
recurse: bool = False,
):
"""
A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
Args:
module (`torch.nn.Module`):
The module we want the tensors on.
recurse (`bool`, *optional`, defaults to `False`):
Whether or not to go look in every submodule or just return the direct parameters and buffers.
"""
yield from module.named_parameters(recurse=recurse)
for named_buffer in module.named_buffers(recurse=recurse):
name, _ = named_buffer
# Get parent by splitting on dots and traversing the model
parent = module
if "." in name:
parent_name = name.rsplit(".", 1)[0]
for part in parent_name.split("."):
parent = getattr(parent, part)
name = name.split(".")[-1]
if name not in parent._non_persistent_buffers_set:
yield named_buffer
def compute_module_persistent_sizes(
model: nn.Module,
dtype: str | torch.device | None = None,
special_dtypes: dict[str, str | torch.device] | None = None,
):
"""
Compute the size of each submodule of a given model (parameters + persistent buffers).
"""
if dtype is not None:
dtype = _get_proper_dtype(dtype)
dtype_size = dtype_byte_size(dtype)
if special_dtypes is not None:
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
module_sizes = defaultdict(int)
module_list = []
module_list = named_persistent_module_tensors(model, recurse=True)
for name, tensor in module_list:
if special_dtypes is not None and name in special_dtypes:
size = tensor.numel() * special_dtypes_size[name]
elif dtype is None:
size = tensor.numel() * dtype_byte_size(tensor.dtype)
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
# According to the code in set_module_tensor_to_device, these types won't be converted
# so use their original size here
size = tensor.numel() * dtype_byte_size(tensor.dtype)
else:
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
name_parts = name.split(".")
for idx in range(len(name_parts) + 1):
module_sizes[".".join(name_parts[:idx])] += size
return module_sizes
def calculate_expected_num_shards(index_map_path):
"""
Calculate expected number of shards from index file.
Args:
index_map_path: Path to the sharded checkpoint index file
Returns:
int: Expected number of shards
"""
with open(index_map_path) as f:
weight_map_dict = json.load(f)["weight_map"]
first_key = list(weight_map_dict.keys())[0]
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
return expected_num_shards
def check_device_map_is_respected(model, device_map):
for param_name, param in model.named_parameters():
# Find device in device_map
while len(param_name) > 0 and param_name not in device_map:
param_name = ".".join(param_name.split(".")[:-1])
if param_name not in device_map:
raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
param_device = device_map[param_name]
if param_device in ["cpu", "disk"]:
assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}"
else:
assert param.device == torch.device(param_device), (
f"Expected device {param_device} for {param_name}, got {param.device}"
)
def cast_inputs_to_dtype(inputs, current_dtype, target_dtype):
if torch.is_tensor(inputs):
return inputs.to(target_dtype) if inputs.dtype == current_dtype else inputs
if isinstance(inputs, dict):
return {k: cast_inputs_to_dtype(v, current_dtype, target_dtype) for k, v in inputs.items()}
if isinstance(inputs, list):
return [cast_inputs_to_dtype(v, current_dtype, target_dtype) for v in inputs]
return inputs
class BaseModelTesterConfig:
"""
Base class defining the configuration interface for model testing.
This class defines the contract that all model test classes must implement.
It provides a consistent interface for accessing model configuration, initialization
parameters, and test inputs across all testing mixins.
Required properties (must be implemented by subclasses):
- model_class: The model class to test
Optional properties (can be overridden, have sensible defaults):
- pretrained_model_name_or_path: Hub repository ID for pretrained model (default: None)
- pretrained_model_kwargs: Additional kwargs for from_pretrained (default: {})
- output_shape: Expected output shape for output validation tests (default: None)
- model_split_percents: Percentages for model parallelism tests (default: [0.5, 0.7])
Required methods (must be implemented by subclasses):
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Example usage:
class MyModelTestConfig(BaseModelTesterConfig):
@property
def model_class(self):
return MyModel
@property
def pretrained_model_name_or_path(self):
return "org/my-model"
@property
def output_shape(self):
return (1, 3, 32, 32)
def get_init_dict(self):
return {"in_channels": 3, "out_channels": 3}
def get_dummy_inputs(self):
return {"sample": torch.randn(1, 3, 32, 32, device=torch_device)}
class TestMyModel(MyModelTestConfig, ModelTesterMixin, QuantizationTesterMixin):
pass
"""
# ==================== Required Properties ====================
@property
def model_class(self) -> Type[nn.Module]:
"""The model class to test. Must be implemented by subclasses."""
raise NotImplementedError("Subclasses must implement the `model_class` property.")
# ==================== Optional Properties ====================
@property
def pretrained_model_name_or_path(self) -> Optional[str]:
"""Hub repository ID for the pretrained model (used for quantization and hub tests)."""
return None
@property
def pretrained_model_kwargs(self) -> Dict[str, Any]:
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
return {}
@property
def output_shape(self) -> Optional[tuple]:
"""Expected output shape for output validation tests."""
return None
@property
def model_split_percents(self) -> list:
"""Percentages for model parallelism tests."""
return [0.9]
# ==================== Required Methods ====================
def get_init_dict(self) -> Dict[str, Any]:
"""
Returns dict of arguments to initialize the model.
Returns:
Dict[str, Any]: Initialization arguments for the model constructor.
Example:
return {
"in_channels": 3,
"out_channels": 3,
"sample_size": 32,
}
"""
raise NotImplementedError("Subclasses must implement `get_init_dict()`.")
def get_dummy_inputs(self) -> Dict[str, Any]:
"""
Returns dict of inputs to pass to the model forward pass.
Returns:
Dict[str, Any]: Input tensors/values for model.forward().
Example:
return {
"sample": torch.randn(1, 3, 32, 32, device=torch_device),
"timestep": torch.tensor([1], device=torch_device),
}
"""
raise NotImplementedError("Subclasses must implement `get_dummy_inputs()`.")
class ModelTesterMixin:
"""
Base mixin class for model testing with common test methods.
This mixin expects the test class to also inherit from BaseModelTesterConfig
(or implement its interface) which provides:
- model_class: The model class to test
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Example:
class MyModelTestConfig(BaseModelTesterConfig):
model_class = MyModel
def get_init_dict(self): ...
def get_dummy_inputs(self): ...
class TestMyModel(MyModelTestConfig, ModelTesterMixin):
pass
"""
@torch.no_grad()
def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
torch.manual_seed(0)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
model.save_pretrained(tmp_path)
new_model = self.model_class.from_pretrained(tmp_path)
new_model.to(torch_device)
for param_name in model.state_dict().keys():
param_1 = model.state_dict()[param_name]
param_2 = new_model.state_dict()[param_name]
assert param_1.shape == param_2.shape, (
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
)
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@torch.no_grad()
def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
model.save_pretrained(tmp_path, variant="fp16")
new_model = self.model_class.from_pretrained(tmp_path, variant="fp16")
with pytest.raises(OSError) as exc_info:
self.model_class.from_pretrained(tmp_path)
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
new_model.to(torch_device)
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
def test_from_save_pretrained_dtype(self, tmp_path, dtype):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
if torch_device == "mps" and dtype == torch.bfloat16:
pytest.skip(reason=f"{dtype} is not supported on {torch_device}")
model.to(dtype)
model.save_pretrained(tmp_path)
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype
if hasattr(self.model_class, "_keep_in_fp32_modules") and self.model_class._keep_in_fp32_modules is None:
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=False, torch_dtype=dtype)
assert new_model.dtype == dtype
@torch.no_grad()
def test_determinism(self, atol=1e-5, rtol=0):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
first_flat = first.flatten()
second_flat = second.flatten()
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
first_filtered = first_flat[mask]
second_filtered = second_flat[mask]
assert_tensors_close(
first_filtered, second_filtered, atol=atol, rtol=rtol, msg="Model outputs are not deterministic"
)
@torch.no_grad()
def test_output(self, expected_output_shape=None):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
inputs_dict = self.get_dummy_inputs()
output = model(**inputs_dict, return_dict=False)[0]
assert output is not None, "Model output is None"
assert output[0].shape == expected_output_shape or self.output_shape, (
f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
)
@torch.no_grad()
def test_outputs_equivalence(self, atol=1e-5, rtol=0):
def set_nan_tensor_to_zero(t):
device = t.device
if device.type == "mps":
t = t.to("cpu")
t[t != t] = 0
return t.to(device)
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (list, tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
assert_tensors_close(
set_nan_tensor_to_zero(tuple_object),
set_nan_tensor_to_zero(dict_object),
atol=atol,
rtol=rtol,
msg="Tuple and dict output are not equal",
)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
outputs_dict = model(**self.get_dummy_inputs())
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
def test_getattr_is_correct(self, caplog):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
# save some things to test
model.dummy_attribute = 5
model.register_to_config(test_attribute=5)
logger_name = "diffusers.models.modeling_utils"
with caplog.at_level(logging.WARNING, logger=logger_name):
caplog.clear()
assert hasattr(model, "dummy_attribute")
assert getattr(model, "dummy_attribute") == 5
assert model.dummy_attribute == 5
# no warning should be thrown
assert caplog.text == ""
with caplog.at_level(logging.WARNING, logger=logger_name):
caplog.clear()
assert hasattr(model, "save_pretrained")
fn = model.save_pretrained
fn_1 = getattr(model, "save_pretrained")
assert fn == fn_1
# no warning should be thrown
assert caplog.text == ""
# warning should be thrown for config attributes accessed directly
with pytest.warns(FutureWarning):
assert model.test_attribute == 5
with pytest.warns(FutureWarning):
assert getattr(model, "test_attribute") == 5
with pytest.raises(AttributeError) as error:
model.does_not_exist
assert str(error.value) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
@require_accelerator
@pytest.mark.skipif(
torch_device not in ["cuda", "xpu"],
reason="float16 and bfloat16 can only be used with an accelerator",
)
def test_keep_in_fp32_modules(self):
model = self.model_class(**self.get_init_dict())
fp32_modules = model._keep_in_fp32_modules
if fp32_modules is None or len(fp32_modules) == 0:
pytest.skip("Model does not have _keep_in_fp32_modules defined.")
# Test with float16
model.to(torch_device)
model.to(torch.float16)
for name, param in model.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
assert param.dtype == torch.float32, f"Parameter {name} should be float32 but got {param.dtype}"
else:
assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}"
@require_accelerator
@pytest.mark.skipif(
torch_device not in ["cuda", "xpu"],
reason="float16 and bfloat16 can only be use for inference with an accelerator",
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
@torch.no_grad()
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
fp32_modules = model._keep_in_fp32_modules or []
model.to(dtype).save_pretrained(tmp_path)
model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device)
for name, param in model_loaded.named_parameters():
if fp32_modules and any(
module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules
):
assert param.data.dtype == torch.float32
else:
assert param.data.dtype == dtype
inputs = cast_inputs_to_dtype(self.get_dummy_inputs(), torch.float32, dtype)
output = model(**inputs, return_dict=False)[0]
output_loaded = model_loaded(**inputs, return_dict=False)[0]
self._check_dtype_inference_output(output, output_loaded, dtype)
def _check_dtype_inference_output(self, output, output_loaded, dtype, atol=1e-4, rtol=0):
"""Check dtype inference output with configurable tolerance."""
assert_tensors_close(
output, output_loaded, atol=atol, rtol=rtol, msg=f"Loaded model output differs for {dtype}"
)
@require_accelerator
@torch.no_grad()
def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
new_model = self.model_class.from_pretrained(tmp_path).eval()
new_model = new_model.to(torch_device)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
)
@require_accelerator
@torch.no_grad()
def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
variant = "fp16"
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB", variant=variant)
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
assert os.path.exists(os.path.join(tmp_path, index_filename)), (
f"Variant index file {index_filename} should exist"
)
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, index_filename))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
new_model = self.model_class.from_pretrained(tmp_path, variant=variant).eval()
new_model = new_model.to(torch_device)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"
)
@torch.no_grad()
def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rtol=0):
from diffusers.utils import constants
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_persistent_sizes(model)[""]
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
# Save original values to restore after test
original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING
original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None)
try:
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
# Check if the right number of shards exists
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
assert actual_num_shards == expected_num_shards, (
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
)
# Load without parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = False
model_sequential = self.model_class.from_pretrained(tmp_path).eval()
model_sequential = model_sequential.to(torch_device)
# Load with parallel loading
constants.HF_ENABLE_PARALLEL_LOADING = True
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
torch.manual_seed(0)
model_parallel = self.model_class.from_pretrained(tmp_path).eval()
model_parallel = model_parallel.to(torch_device)
torch.manual_seed(0)
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
assert_tensors_close(
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"
)
finally:
# Restore original values
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
if original_parallel_workers is not None:
constants.HF_PARALLEL_WORKERS = original_parallel_workers
@require_torch_multi_accelerator
@torch.no_grad()
def test_model_parallelism(self, tmp_path, atol=1e-5, rtol=0):
if self.model_class._no_split_modules is None:
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict, return_dict=False)[0]
model_size = compute_module_sizes(model)[""]
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
model.cpu().save_pretrained(tmp_path)
for max_size in max_gpu_sizes:
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(tmp_path, device_map="auto", max_memory=max_memory)
# Making sure part of the model will be on GPU 0 and GPU 1
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match with model parallelism"
)

View File

@@ -1,166 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import pytest
import torch
from ...testing_utils import (
backend_empty_cache,
is_torch_compile,
require_accelerator,
require_torch_version_greater,
torch_device,
)
@is_torch_compile
@require_accelerator
@require_torch_version_greater("2.7.1")
class TorchCompileTesterMixin:
"""
Mixin class for testing torch.compile functionality on models.
Expected from config mixin:
- model_class: The model class to test
Optional properties:
- different_shapes_for_compilation: List of (height, width) tuples for dynamic shape testing (default: None)
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: compile
Use `pytest -m "not compile"` to skip these tests
"""
@property
def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None:
"""Optional list of (height, width) tuples for dynamic shape testing."""
return None
def setup_method(self):
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
@torch.no_grad()
def test_torch_compile_recompilation_and_graph_break(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model = torch.compile(model, fullgraph=True)
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
@torch.no_grad()
def test_torch_compile_repeated_blocks(self):
if self.model_class._repeated_blocks is None:
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile_repeated_blocks(fullgraph=True)
recompile_limit = 1
if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=recompile_limit),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
@torch.no_grad()
def test_compile_with_group_offloading(self):
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
torch._dynamo.config.cache_size_limit = 10000
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.eval()
group_offload_kwargs = {
"onload_device": torch_device,
"offload_device": "cpu",
"offload_type": "block_level",
"num_blocks_per_group": 1,
"use_stream": True,
"non_blocking": True,
}
model.enable_group_offload(**group_offload_kwargs)
model.compile()
_ = model(**inputs_dict)
_ = model(**inputs_dict)
@torch.no_grad()
def test_compile_on_different_shapes(self):
if self.different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
torch.fx.experimental._config.use_duck_shape = False
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model = torch.compile(model, fullgraph=True, dynamic=True)
for height, width in self.different_shapes_for_compilation:
with torch._dynamo.config.patch(error_on_recompile=True):
inputs_dict = self.get_dummy_inputs(height=height, width=width)
_ = model(**inputs_dict)
@torch.no_grad()
def test_compile_works_with_aot(self, tmp_path):
from torch._inductor.package import load_package
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
package_path = os.path.join(str(tmp_path), f"{self.model_class.__name__}.pt2")
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
assert os.path.exists(package_path), f"Package file not created at {package_path}"
loaded_binary = load_package(package_path, run_single_threaded=True)
model.forward = loaded_binary
_ = model(**inputs_dict)
_ = model(**inputs_dict)

View File

@@ -1,158 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import pytest
import torch
from ...testing_utils import backend_empty_cache, is_ip_adapter, torch_device
def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool:
"""
Check if IP Adapter processors are correctly set in the model.
Args:
model: The model to check
Returns:
bool: True if IP Adapter is correctly set, False otherwise
"""
for module in model.attn_processors.values():
if isinstance(module, processor_cls):
return True
return False
@is_ip_adapter
class IPAdapterTesterMixin:
"""
Mixin class for testing IP Adapter functionality on models.
Expected from config mixin:
- model_class: The model class to test
Required properties (must be implemented by subclasses):
- ip_adapter_processor_cls: The IP Adapter processor class to use
Required methods (must be implemented by subclasses):
- create_ip_adapter_state_dict(): Creates IP Adapter state dict for testing
- modify_inputs_for_ip_adapter(): Modifies inputs to include IP Adapter data
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: ip_adapter
Use `pytest -m "not ip_adapter"` to skip these tests
"""
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@property
def ip_adapter_processor_cls(self):
"""IP Adapter processor class to use for testing. Must be implemented by subclasses."""
raise NotImplementedError("Subclasses must implement the `ip_adapter_processor_cls` property.")
def create_ip_adapter_state_dict(self, model):
raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
raise NotImplementedError("child class must implement method to create IPAdapter model inputs")
@torch.no_grad()
def test_load_ip_adapter(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
torch.manual_seed(0)
output_no_adapter = model(**inputs_dict, return_dict=False)[0]
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
"IP Adapter processors not set correctly"
)
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]
assert not torch.allclose(output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4), (
"Output should differ with IP Adapter enabled"
)
@pytest.mark.skip(
reason="Setting IP Adapter scale is not defined at the model level. Enable this test after refactoring"
)
def test_ip_adapter_scale(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
# Test scale = 0.0 (no effect)
model.set_ip_adapter_scale(0.0)
torch.manual_seed(0)
output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0]
# Test scale = 1.0 (full effect)
model.set_ip_adapter_scale(1.0)
torch.manual_seed(0)
output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]
# Outputs should differ with different scales
assert not torch.allclose(output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4), (
"Output should differ with different IP Adapter scales"
)
@pytest.mark.skip(
reason="Unloading IP Adapter is not defined at the model level. Enable this test after refactoring"
)
def test_unload_ip_adapter(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
# Save original processors
original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
# Create and load IP adapter
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
model._load_ip_adapter_weights([ip_adapter_state_dict])
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), "IP Adapter should be set"
# Unload IP adapter
model.unload_ip_adapter()
assert not check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
"IP Adapter should be unloaded"
)
# Verify processors are restored
current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
assert original_processors == current_processors, "Processors should be restored after unload"

View File

@@ -1,555 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import json
import os
import re
import pytest
import safetensors.torch
import torch
import torch.nn as nn
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import check_if_dicts_are_equal
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
is_lora,
is_torch_compile,
require_peft_backend,
require_peft_version_greater,
require_torch_accelerator,
require_torch_version_greater,
torch_device,
)
if is_peft_available():
from diffusers.loaders.peft import PeftAdapterMixin
def check_if_lora_correctly_set(model) -> bool:
"""
Check if LoRA layers are correctly set in the model.
Args:
model: The model to check
Returns:
bool: True if LoRA is correctly set, False otherwise
"""
from peft.tuners.tuners_utils import BaseTunerLayer
for module in model.modules():
if isinstance(module, BaseTunerLayer):
return True
return False
@is_lora
@require_peft_backend
class LoraTesterMixin:
"""
Mixin class for testing LoRA/PEFT functionality on models.
Expected from config mixin:
- model_class: The model class to test
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: lora
Use `pytest -m "not lora"` to skip these tests
"""
def setup_method(self):
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
@torch.no_grad()
def test_save_load_lora_adapter(self, tmp_path, rank=4, lora_alpha=4, use_dora=False, atol=1e-4, rtol=1e-4):
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
torch.manual_seed(0)
output_no_lora = model(**inputs_dict, return_dict=False)[0]
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
torch.manual_seed(0)
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
assert not torch.allclose(output_no_lora, outputs_with_lora, atol=atol, rtol=rtol), (
"Output should differ with LoRA enabled"
)
model.save_lora_adapter(tmp_path)
assert os.path.isfile(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")), (
"LoRA weights file not created"
)
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmp_path, "pytorch_lora_weights.safetensors"))
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
for k in state_dict_loaded:
loaded_v = state_dict_loaded[k]
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
assert_tensors_close(loaded_v, retrieved_v, atol=atol, rtol=rtol, msg=f"Mismatch in LoRA weight {k}")
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"
torch.manual_seed(0)
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=atol, rtol=rtol), (
"Output should differ with LoRA enabled"
)
assert_tensors_close(
outputs_with_lora,
outputs_with_lora_2,
atol=atol,
rtol=rtol,
msg="Outputs should match before and after save/load",
)
def test_lora_wrong_adapter_name_raises_error(self, tmp_path):
from peft import LoraConfig
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
wrong_name = "foo"
with pytest.raises(ValueError) as exc_info:
model.save_lora_adapter(tmp_path, adapter_name=wrong_name)
assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)
def test_lora_adapter_metadata_is_loaded_correctly(self, tmp_path, rank=4, lora_alpha=4, use_dora=False):
from peft import LoraConfig
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
metadata = model.peft_config["default"].to_dict()
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
model.save_lora_adapter(tmp_path)
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
check_if_dicts_are_equal(metadata, parsed_metadata)
def test_lora_adapter_wrong_metadata_raises_error(self, tmp_path):
from peft import LoraConfig
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
model.save_lora_adapter(tmp_path)
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
assert os.path.isfile(model_file), "LoRA weights file not created"
# Perturb the metadata in the state dict
loaded_state_dict = safetensors.torch.load_file(model_file)
metadata = {"format": "pt"}
lora_adapter_metadata = denoiser_lora_config.to_dict()
lora_adapter_metadata.update({"foo": 1, "bar": 2})
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
model.unload_lora()
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
with pytest.raises(TypeError) as exc_info:
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)
@is_lora
@is_torch_compile
@require_peft_backend
@require_peft_version_greater("0.14.0")
@require_torch_version_greater("2.7.1")
@require_torch_accelerator
class LoraHotSwappingForModelTesterMixin:
"""
Mixin class for testing LoRA hot swapping functionality on models.
Test that hotswapping does not result in recompilation on the model directly.
We're not extensively testing the hotswapping functionality since it is implemented in PEFT
and is extensively tested there. The goal of this test is specifically to ensure that
hotswapping with diffusers does not require recompilation.
See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
for the analogous PEFT test.
Expected from config mixin:
- model_class: The model class to test
Optional properties:
- different_shapes_for_compilation: List of (height, width) tuples for dynamic compilation tests (default: None)
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest marks: lora, torch_compile
Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests
"""
@property
def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None:
"""Optional list of (height, width) tuples for dynamic compilation tests."""
return None
def setup_method(self):
if not issubclass(self.model_class, PeftAdapterMixin):
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
def teardown_method(self):
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
# there will be recompilation errors, as torch caches the model when run in the same process.
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
def _get_lora_config(self, lora_rank, lora_alpha, target_modules):
from peft import LoraConfig
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
target_modules=target_modules,
init_lora_weights=False,
use_dora=False,
)
return lora_config
def _get_linear_module_name_other_than_attn(self, model):
linear_names = [
name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
]
return linear_names[0]
def _check_model_hotswap(
self, tmp_path, do_compile, rank0, rank1, target_modules0, target_modules1=None, atol=5e-3, rtol=5e-3
):
"""
Check that hotswapping works on a model.
Steps:
- create 2 LoRA adapters and save them
- load the first adapter
- hotswap the second adapter
- check that the outputs are correct
- optionally compile the model
- optionally check if recompilations happen on different shapes
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
fine.
"""
different_shapes = self.different_shapes_for_compilation
# create 2 adapters with different ranks and alphas
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
alpha0, alpha1 = rank0, rank1
max_rank = max([rank0, rank1])
if target_modules1 is None:
target_modules1 = target_modules0[:]
lora_config0 = self._get_lora_config(rank0, alpha0, target_modules0)
lora_config1 = self._get_lora_config(rank1, alpha1, target_modules1)
model.add_adapter(lora_config0, adapter_name="adapter0")
with torch.inference_mode():
torch.manual_seed(0)
output0_before = model(**inputs_dict)["sample"]
model.add_adapter(lora_config1, adapter_name="adapter1")
model.set_adapter("adapter1")
with torch.inference_mode():
torch.manual_seed(0)
output1_before = model(**inputs_dict)["sample"]
# sanity checks:
assert not torch.allclose(output0_before, output1_before, atol=atol, rtol=rtol)
assert not (output0_before == 0).all()
assert not (output1_before == 0).all()
# save the adapter checkpoints
model.save_lora_adapter(os.path.join(tmp_path, "0"), safe_serialization=True, adapter_name="adapter0")
model.save_lora_adapter(os.path.join(tmp_path, "1"), safe_serialization=True, adapter_name="adapter1")
del model
# load the first adapter
torch.manual_seed(0)
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
if do_compile or (rank0 != rank1):
# no need to prepare if the model is not compiled or if the ranks are identical
model.enable_lora_hotswap(target_rank=max_rank)
file_name0 = os.path.join(os.path.join(tmp_path, "0"), "pytorch_lora_weights.safetensors")
file_name1 = os.path.join(os.path.join(tmp_path, "1"), "pytorch_lora_weights.safetensors")
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
if do_compile:
model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
with torch.inference_mode():
# additionally check if dynamic compilation works.
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output0_after = model(**inputs_dict)["sample"]
assert_tensors_close(
output0_before, output0_after, atol=atol, rtol=rtol, msg="Output mismatch after loading adapter0"
)
# hotswap the 2nd adapter
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
if different_shapes is not None:
for height, width in different_shapes:
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
_ = model(**new_inputs_dict)
else:
output1_after = model(**inputs_dict)["sample"]
assert_tensors_close(
output1_before,
output1_after,
atol=atol,
rtol=rtol,
msg="Output mismatch after hotswapping to adapter1",
)
# check error when not passing valid adapter name
name = "does-not-exist"
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
with pytest.raises(ValueError, match=re.escape(msg)):
model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_model(self, tmp_path, rank0, rank1):
self._check_model_hotswap(
tmp_path, do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_linear(self, tmp_path, rank0, rank1):
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_conv2d(self, tmp_path, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
pytest.skip("Test only applies to UNet.")
# It's important to add this context to raise an error on recompilation
target_modules = ["conv", "conv1", "conv2"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, tmp_path, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
pytest.skip("Test only applies to UNet.")
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "conv"]
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
def test_hotswapping_compiled_model_both_linear_and_other(self, tmp_path, rank0, rank1):
# In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
# with `torch.compile()` for models that have both linear and conv layers. In this test, we check
# if we can target a linear layer from the transformer blocks and another linear layer from non-attention
# block.
target_modules = ["to_q"]
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
target_modules.append(self._get_linear_module_name_other_than_attn(model))
del model
# It's important to add this context to raise an error on recompilation
with torch._dynamo.config.patch(error_on_recompile=True):
self._check_model_hotswap(
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
)
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
# ensure that enable_lora_hotswap is called before loading the first adapter
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
with pytest.raises(RuntimeError, match=msg):
model.enable_lora_hotswap(target_rank=32)
def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
# ensure that enable_lora_hotswap is called before loading the first adapter
import logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in record.message for record in caplog.records)
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
# check possibility to ignore the error/warning
import logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert len(caplog.records) == 0
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
with pytest.raises(ValueError, match=msg):
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog):
# check the error and log
import logging
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
target_modules0 = ["to_q"]
target_modules1 = ["to_q", "to_k"]
with pytest.raises(RuntimeError): # peft raises RuntimeError
with caplog.at_level(logging.ERROR):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=8,
rank1=8,
target_modules0=target_modules0,
target_modules1=target_modules1,
)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
@require_torch_version_greater("2.7.1")
def test_hotswapping_compile_on_different_shapes(self, tmp_path, rank0, rank1):
different_shapes_for_compilation = self.different_shapes_for_compilation
if different_shapes_for_compilation is None:
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
# Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic
# variable to represent input sizes that are the same. For more details,
# check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
torch.fx.experimental._config.use_duck_shape = False
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
with torch._dynamo.config.patch(error_on_recompile=True):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=rank0,
rank1=rank1,
target_modules0=target_modules,
)

View File

@@ -1,498 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import glob
import inspect
from functools import wraps
import pytest
import torch
from accelerate.utils.modeling import compute_module_sizes
from diffusers.utils.testing_utils import _check_safetensors_serialization
from diffusers.utils.torch_utils import get_torch_cuda_device_capability
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
backend_synchronize,
is_cpu_offload,
is_group_offload,
is_memory,
require_accelerator,
torch_device,
)
from .common import cast_inputs_to_dtype, check_device_map_is_respected
def require_offload_support(func):
"""
Decorator to skip tests if model doesn't support offloading (requires _no_split_modules).
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if self.model_class._no_split_modules is None:
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
return func(self, *args, **kwargs)
return wrapper
def require_group_offload_support(func):
"""
Decorator to skip tests if model doesn't support group offloading.
"""
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
return func(self, *args, **kwargs)
return wrapper
@is_cpu_offload
class CPUOffloadTesterMixin:
"""
Mixin class for testing CPU offloading functionality.
Expected from config mixin:
- model_class: The model class to test
Optional properties:
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cpu_offload
Use `pytest -m "not cpu_offload"` to skip these tests
"""
@property
def model_split_percents(self) -> list[float]:
"""List of percentages for splitting model across devices during offloading tests."""
return [0.5, 0.7]
@require_offload_support
@torch.no_grad()
def test_cpu_offload(self, tmp_path, atol=1e-5, rtol=0):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
model.cpu().save_pretrained(str(tmp_path))
for max_size in max_gpu_sizes:
max_memory = {0: max_size, "cpu": model_size * 2}
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert_tensors_close(
base_output[0], new_output[0], atol=atol, rtol=rtol, msg="Output should match with CPU offloading"
)
@require_offload_support
@torch.no_grad()
def test_disk_offload_without_safetensors(self, tmp_path, atol=1e-5, rtol=0):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
max_size = int(self.model_split_percents[0] * model_size)
# Force disk offload by setting very small CPU memory
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
model.cpu().save_pretrained(str(tmp_path), safe_serialization=False)
# This errors out because it's missing an offload folder
with pytest.raises(ValueError):
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
new_model = self.model_class.from_pretrained(
str(tmp_path), device_map="auto", max_memory=max_memory, offload_folder=str(tmp_path)
)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert_tensors_close(
base_output[0], new_output[0], atol=atol, rtol=rtol, msg="Output should match with disk offloading"
)
@require_offload_support
@torch.no_grad()
def test_disk_offload_with_safetensors(self, tmp_path, atol=1e-5, rtol=0):
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
model_size = compute_module_sizes(model)[""]
model.cpu().save_pretrained(str(tmp_path))
max_size = int(self.model_split_percents[0] * model_size)
max_memory = {0: max_size, "cpu": max_size}
new_model = self.model_class.from_pretrained(
str(tmp_path), device_map="auto", offload_folder=str(tmp_path), max_memory=max_memory
)
check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
assert_tensors_close(
base_output[0],
new_output[0],
atol=atol,
rtol=rtol,
msg="Output should match with disk offloading (safetensors)",
)
@is_group_offload
class GroupOffloadTesterMixin:
"""
Mixin class for testing group offloading functionality.
Expected from config mixin:
- model_class: The model class to test
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: group_offload
Use `pytest -m "not group_offload"` to skip these tests
"""
@require_group_offload_support
@pytest.mark.parametrize("record_stream", [False, True])
def test_group_offloading(self, record_stream, atol=1e-5, rtol=0):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
@torch.no_grad()
def run_forward(model):
assert all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in model.modules()
if hasattr(module, "_diffusers_hook")
), "Group offloading hook should be set"
model.eval()
return model(**inputs_dict)[0]
model = self.model_class(**init_dict)
model.to(torch_device)
output_without_group_offloading = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
output_with_group_offloading1 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
output_with_group_offloading2 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(torch_device, offload_type="leaf_level")
output_with_group_offloading3 = run_forward(model)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.enable_group_offload(
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
)
output_with_group_offloading4 = run_forward(model)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading1,
atol=atol,
rtol=rtol,
msg="Output should match with block-level offloading",
)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading2,
atol=atol,
rtol=rtol,
msg="Output should match with non-blocking block-level offloading",
)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading3,
atol=atol,
rtol=rtol,
msg="Output should match with leaf-level offloading",
)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading4,
atol=atol,
rtol=rtol,
msg="Output should match with leaf-level offloading with stream",
)
@require_group_offload_support
@pytest.mark.parametrize("record_stream", [False, True])
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
@torch.no_grad()
def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
_ = model(**inputs_dict)[0]
torch.manual_seed(0)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
storage_dtype, compute_dtype = torch.float16, torch.float32
inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype)
model = self.model_class(**init_dict)
model.eval()
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
model.enable_group_offload(
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
_ = model(**inputs_dict)[0]
@require_group_offload_support
@pytest.mark.parametrize("record_stream", [False, True])
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
@torch.no_grad()
@torch.inference_mode()
def test_group_offloading_with_disk(self, tmp_path, record_stream, offload_type, atol=1e-5, rtol=0):
def _has_generator_arg(model):
sig = inspect.signature(model.forward)
params = sig.parameters
return "generator" in params
def _run_forward(model, inputs_dict):
accepts_generator = _has_generator_arg(model)
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
torch.manual_seed(0)
return model(**inputs_dict)[0]
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.eval()
model.to(torch_device)
output_without_group_offloading = _run_forward(model, inputs_dict)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.eval()
num_blocks_per_group = None if offload_type == "leaf_level" else 1
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
tmpdir = str(tmp_path)
model.enable_group_offload(
torch_device,
offload_type=offload_type,
offload_to_disk_path=tmpdir,
use_stream=True,
record_stream=record_stream,
**additional_kwargs,
)
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
assert has_safetensors, "No safetensors found in the directory."
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
# in nature. So, skip it.
if offload_type != "leaf_level":
is_correct, extra_files, missing_files = _check_safetensors_serialization(
module=model,
offload_to_disk_path=tmpdir,
offload_type=offload_type,
num_blocks_per_group=num_blocks_per_group,
)
if not is_correct:
if extra_files:
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
elif missing_files:
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
output_with_group_offloading = _run_forward(model, inputs_dict)
assert_tensors_close(
output_without_group_offloading,
output_with_group_offloading,
atol=atol,
rtol=rtol,
msg="Output should match with disk-based group offloading",
)
class LayerwiseCastingTesterMixin:
"""
Mixin class for testing layerwise dtype casting for memory optimization.
Expected from config mixin:
- model_class: The model class to test
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
"""
@torch.no_grad()
def test_layerwise_casting_memory(self):
MB_TOLERANCE = 0.2
LEAST_COMPUTE_CAPABILITY = 8.0
def reset_memory_stats():
gc.collect()
backend_synchronize(torch_device)
backend_empty_cache(torch_device)
backend_reset_peak_memory_stats(torch_device)
def get_memory_usage(storage_dtype, compute_dtype):
torch.manual_seed(0)
config = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype)
model = self.model_class(**config).eval()
model = model.to(torch_device, dtype=compute_dtype)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
reset_memory_stats()
model(**inputs_dict)
model_memory_footprint = model.get_memory_footprint()
peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
return model_memory_footprint, peak_inference_memory_allocated_mb
fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
torch.float8_e4m3fn, torch.bfloat16
)
compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
assert fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint, (
"Memory footprint should decrease with lower precision storage"
)
# NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
assert fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory, (
"Peak memory should be lower with bf16 compute on newer GPUs"
)
# On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
# bytes. This only happens for some models, so we allow a small tolerance.
# For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
assert (
fp8_e4m3_fp32_max_memory < fp32_max_memory
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
), "Peak memory should be lower or within tolerance with fp8 storage"
def test_layerwise_casting_training(self):
def test_fn(storage_dtype, compute_dtype):
if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
pytest.skip("Skipping test because CPU doesn't go well with bfloat16.")
model = self.model_class(**self.get_init_dict())
model = model.to(torch_device, dtype=compute_dtype)
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
model.train()
inputs_dict = self.get_dummy_inputs()
inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype)
with torch.amp.autocast(device_type=torch.device(torch_device).type):
output = model(**inputs_dict, return_dict=False)[0]
input_tensor = inputs_dict[self.main_input_name]
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
noise = cast_inputs_to_dtype(noise, torch.float32, compute_dtype)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
test_fn(torch.float16, torch.float32)
test_fn(torch.float8_e4m3fn, torch.float32)
test_fn(torch.float8_e5m2, torch.float32)
test_fn(torch.float8_e4m3fn, torch.bfloat16)
@is_memory
@require_accelerator
class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin):
"""
Combined mixin class for all memory optimization tests including CPU/disk offloading,
group offloading, and layerwise dtype casting.
This mixin inherits from:
- CPUOffloadTesterMixin: CPU and disk offloading tests
- GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level)
- LayerwiseCastingTesterMixin: Layerwise dtype casting tests
Expected from config mixin:
- model_class: The model class to test
Optional properties:
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: memory
Use `pytest -m "not memory"` to skip these tests
"""

View File

@@ -1,128 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import socket
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from diffusers.models._modeling_parallel import ContextParallelConfig
from ...testing_utils import (
is_context_parallel,
require_torch_multi_accelerator,
)
def _find_free_port():
"""Find a free port on localhost."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
s.listen(1)
port = s.getsockname()[1]
return port
def _context_parallel_worker(rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict):
"""Worker function for context parallel testing."""
try:
# Set up distributed environment
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
# Initialize process group
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
# Set device for this process
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
# Create model
model = model_class(**init_dict)
model.to(device)
model.eval()
# Move inputs to device
inputs_on_device = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
inputs_on_device[key] = value.to(device)
else:
inputs_on_device[key] = value
# Enable context parallelism
cp_config = ContextParallelConfig(**cp_dict)
model.enable_parallelism(config=cp_config)
# Run forward pass
with torch.no_grad():
output = model(**inputs_on_device, return_dict=False)[0]
# Only rank 0 reports results
if rank == 0:
return_dict["status"] = "success"
return_dict["output_shape"] = list(output.shape)
except Exception as e:
if rank == 0:
return_dict["status"] = "error"
return_dict["error"] = str(e)
finally:
if dist.is_initialized():
dist.destroy_process_group()
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_inference(self, cp_type):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
# Move all tensors to CPU for multiprocessing
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
cp_dict = {cp_type: world_size}
# Find a free port for distributed communication
master_port = _find_free_port()
# Use multiprocessing manager for cross-process communication
manager = mp.Manager()
return_dict = manager.dict()
# Spawn worker processes
mp.spawn(
_context_parallel_worker,
args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict),
nprocs=world_size,
join=True,
)
assert return_dict.get("status") == "success", (
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)

File diff suppressed because it is too large Load Diff

View File

@@ -1,272 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from ...testing_utils import (
backend_empty_cache,
is_single_file,
nightly,
require_torch_accelerator,
torch_device,
)
from .common import check_device_map_is_respected
def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir):
"""Download a single file checkpoint from the Hub to a temporary directory."""
path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir)
return path
def download_diffusers_config(pretrained_model_name_or_path, tmpdir):
"""Download diffusers config files (excluding weights) from a repository."""
path = snapshot_download(
pretrained_model_name_or_path,
ignore_patterns=[
"**/*.ckpt",
"*.ckpt",
"**/*.bin",
"*.bin",
"**/*.pt",
"*.pt",
"**/*.safetensors",
"*.safetensors",
],
allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
local_dir=tmpdir,
)
return path
@nightly
@require_torch_accelerator
@is_single_file
class SingleFileTesterMixin:
"""
Mixin class for testing single file loading for models.
Required properties (must be implemented by subclasses):
- ckpt_path: Path or Hub path to the single file checkpoint
Optional properties:
- torch_dtype: torch dtype to use for testing (default: None)
- alternate_ckpt_paths: List of alternate checkpoint paths for variant testing (default: None)
Expected from config mixin:
- model_class: The model class to test
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
- pretrained_model_kwargs: Additional kwargs for from_pretrained (e.g., subfolder)
Pytest mark: single_file
Use `pytest -m "not single_file"` to skip these tests
"""
# ==================== Required Properties ====================
@property
def ckpt_path(self) -> str:
"""Path or Hub path to the single file checkpoint. Must be implemented by subclasses."""
raise NotImplementedError("Subclasses must implement the `ckpt_path` property.")
# ==================== Optional Properties ====================
@property
def torch_dtype(self) -> torch.dtype | None:
"""torch dtype to use for single file testing."""
return None
@property
def alternate_ckpt_paths(self) -> list[str] | None:
"""List of alternate checkpoint paths for variant testing."""
return None
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def test_single_file_model_config(self):
pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs}
single_file_kwargs = {"device": torch_device}
if self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs between pretrained loading and single file loading: "
f"pretrained={model.config[param_name]}, single_file={param_value}"
)
def test_single_file_model_parameters(self):
pretrained_kwargs = {"device_map": str(torch_device), **self.pretrained_model_kwargs}
single_file_kwargs = {"device": torch_device}
if self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
single_file_kwargs["torch_dtype"] = self.torch_dtype
# Load pretrained model, get state dict on CPU, then free GPU memory
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
del model
gc.collect()
backend_empty_cache(torch_device)
# Load single file model, get state dict on CPU
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
state_dict_single_file = {k: v.cpu() for k, v in model_single_file.state_dict().items()}
del model_single_file
gc.collect()
backend_empty_cache(torch_device)
assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
"Model parameters keys differ between pretrained and single file loading. "
f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. "
f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}"
)
for key in state_dict.keys():
param = state_dict[key]
param_single_file = state_dict_single_file[key]
assert param.shape == param_single_file.shape, (
f"Parameter shape mismatch for {key}: "
f"pretrained {param.shape} vs single file {param_single_file.shape}"
)
assert torch.equal(param, param_single_file), f"Parameter values differ for {key}"
def test_single_file_loading_local_files_only(self, tmp_path):
single_file_kwargs = {}
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
model_single_file = self.model_class.from_single_file(
local_ckpt_path, local_files_only=True, **single_file_kwargs
)
assert model_single_file is not None, "Failed to load model with local_files_only=True"
def test_single_file_loading_with_diffusers_config(self):
single_file_kwargs = {}
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
single_file_kwargs.update(self.pretrained_model_kwargs)
# Load with config parameter
model_single_file = self.model_class.from_single_file(
self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs
)
# Load pretrained for comparison
pretrained_kwargs = {**self.pretrained_model_kwargs}
if self.torch_dtype:
pretrained_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
# Compare configs
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
for param_name, param_value in model_single_file.config.items():
if param_name in PARAMS_TO_IGNORE:
continue
assert model.config[param_name] == param_value, (
f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"
)
def test_single_file_loading_with_diffusers_config_local_files_only(self, tmp_path):
single_file_kwargs = {}
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
single_file_kwargs.update(self.pretrained_model_kwargs)
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, str(tmp_path))
model_single_file = self.model_class.from_single_file(
local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
)
assert model_single_file is not None, "Failed to load model with config and local_files_only=True"
def test_single_file_loading_dtype(self):
for dtype in [torch.float32, torch.float16]:
if torch_device == "mps" and dtype == torch.bfloat16:
continue
model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype)
assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}"
# Cleanup
del model_single_file
gc.collect()
backend_empty_cache(torch_device)
def test_checkpoint_variant_loading(self):
if not self.alternate_ckpt_paths:
return
for ckpt_path in self.alternate_ckpt_paths:
backend_empty_cache(torch_device)
single_file_kwargs = {}
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
assert model is not None, f"Failed to load checkpoint from {ckpt_path}"
del model
gc.collect()
backend_empty_cache(torch_device)
def test_single_file_loading_with_device_map(self):
single_file_kwargs = {"device_map": torch_device}
if self.torch_dtype:
single_file_kwargs["torch_dtype"] = self.torch_dtype
model = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
assert model is not None, "Failed to load model with device_map"
assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute when loaded with device_map"
assert model.hf_device_map is not None, "hf_device_map should not be None when loaded with device_map"
check_device_map_is_respected(model, model.hf_device_map)

View File

@@ -1,220 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import pytest
import torch
from diffusers.training_utils import EMAModel
from ...testing_utils import (
backend_empty_cache,
is_training,
require_torch_accelerator_with_training,
torch_all_close,
torch_device,
)
@is_training
@require_torch_accelerator_with_training
class TrainingTesterMixin:
"""
Mixin class for testing training functionality on models.
Expected from config mixin:
- model_class: The model class to test
- output_shape: Tuple defining the expected output shape
Expected methods from config mixin:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: training
Use `pytest -m "not training"` to skip these tests
"""
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
def test_training(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
def test_training_with_ema(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
ema_model = EMAModel(model.parameters())
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
ema_model.step(model.parameters())
def test_gradient_checkpointing(self):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
init_dict = self.get_init_dict()
# at init model should have gradient checkpointing disabled
model = self.model_class(**init_dict)
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled at init"
# check enable works
model.enable_gradient_checkpointing()
assert model.is_gradient_checkpointing, "Gradient checkpointing should be enabled"
# check disable works
model.disable_gradient_checkpointing()
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled"
def test_gradient_checkpointing_is_applied(self, expected_set=None):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
if expected_set is None:
pytest.skip("expected_set must be provided to verify gradient checkpointing is applied.")
init_dict = self.get_init_dict()
model_class_copy = copy.copy(self.model_class)
model = model_class_copy(**init_dict)
model.enable_gradient_checkpointing()
modules_with_gc_enabled = {}
for submodule in model.modules():
if hasattr(submodule, "gradient_checkpointing"):
assert submodule.gradient_checkpointing, f"{submodule.__class__.__name__} should have GC enabled"
modules_with_gc_enabled[submodule.__class__.__name__] = True
assert set(modules_with_gc_enabled.keys()) == expected_set, (
f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} do not match expected set {expected_set}"
)
assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled"
def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
if not self.model_class._supports_gradient_checkpointing:
pytest.skip("Gradient checkpointing is not supported.")
if skip is None:
skip = set()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
inputs_dict_copy = copy.deepcopy(inputs_dict)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict, return_dict=False)[0]
# run the backwards pass on the model
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
torch.manual_seed(0)
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict_copy, return_dict=False)[0]
# run the backwards pass on the model
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
assert (loss - loss_2).abs() < loss_tolerance, (
f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}"
)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
if "post_quant_conv" in name:
continue
if name in skip:
continue
if param.grad is None:
continue
assert torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol), (
f"Gradient mismatch for {name}"
)
def test_mixed_precision_training(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.train()
# Test with float16
if torch.device(torch_device).type != "cpu":
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16):
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
# Test with bfloat16
if torch.device(torch_device).type != "cpu":
model.zero_grad()
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16):
output = model(**inputs_dict, return_dict=False)[0]
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()

View File

@@ -13,52 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import unittest
import pytest
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesCompileTesterMixin,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
IPAdapterTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelOptCompileTesterMixin,
ModelOptTesterMixin,
ModelTesterMixin,
PyramidAttentionBroadcastTesterMixin,
QuantoCompileTesterMixin,
QuantoTesterMixin,
SingleFileTesterMixin,
TorchAoCompileTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ...testing_utils import enable_full_determinism, is_peft_available, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
# TODO: This standalone function maintains backward compatibility with pipeline tests
# (tests/pipelines/test_pipelines_common.py) and will be refactored.
def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
"""Create a dummy IP Adapter state dict for Flux transformer testing."""
def create_flux_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
ip_cross_attn_state_dict = {}
key_id = 0
@@ -68,7 +39,7 @@ def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
joint_attention_dim = model.config["joint_attention_dim"]
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
sd = FluxIPAdapterAttnProcessor(
sd = FluxIPAdapterJointAttnProcessor2_0(
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
).state_dict()
ip_cross_attn_state_dict.update(
@@ -79,8 +50,11 @@ def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
}
)
key_id += 1
# "image_proj" (ImageProjection layer weights)
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=(
@@ -101,45 +75,57 @@ def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
)
del sd
return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
class FluxTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return FluxTransformer2DModel
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-flux-pipe"
def dummy_input(self):
return self.prepare_dummy_input()
@property
def pretrained_model_kwargs(self):
return {"subfolder": "transformer"}
@property
def output_shape(self) -> tuple[int, int]:
def input_shape(self):
return (16, 4)
@property
def input_shape(self) -> tuple[int, int]:
def output_shape(self):
return (16, 4)
@property
def model_split_percents(self) -> list:
return [0.9]
def prepare_dummy_input(self, height=4, width=4):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 48
embedding_dim = 32
@property
def main_input_name(self) -> str:
return "hidden_states"
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device)
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
"""Return Flux model initialization arguments."""
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"pooled_projections": pooled_prompt_embeds,
"timestep": timestep,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
@@ -151,40 +137,11 @@ class FluxTransformerTesterConfig(BaseModelTesterConfig):
"axes_dims_rope": [4, 4, 8],
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
height = width = 4
num_latent_channels = 4
num_image_channels = 3
sequence_length = 48
embedding_dim = 32
inputs_dict = self.dummy_input
return init_dict, inputs_dict
return {
"hidden_states": randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"pooled_projections": randn_tensor(
(batch_size, embedding_dim), generator=self.generator, device=torch_device
),
"img_ids": randn_tensor(
(height * width, num_image_channels), generator=self.generator, device=torch_device
),
"txt_ids": randn_tensor(
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
def test_deprecated_inputs_img_txt_ids_3d(self):
"""Test that deprecated 3D img_ids and txt_ids still work."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
@@ -205,228 +162,63 @@ class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
with torch.no_grad():
output_2 = model(**inputs_dict).to_tuple()[0]
assert output_1.shape == output_2.shape
assert torch.allclose(output_1, output_2, atol=1e-5), (
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
"are not equal as them as 2d inputs"
self.assertEqual(output_1.shape, output_2.shape)
self.assertTrue(
torch.allclose(output_1, output_2, atol=1e-5),
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Flux Transformer."""
# The test exists for cases like
# https://github.com/huggingface/diffusers/issues/11874
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_exclude_modules(self):
from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
lora_rank = 4
target_module = "single_transformer_blocks.0.proj_out"
adapter_name = "foo"
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux Transformer."""
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux Transformer."""
class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for Flux Transformer"""
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
"""IP Adapter tests for Flux Transformer."""
@property
def ip_adapter_processor_cls(self):
return FluxIPAdapterAttnProcessor
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
torch.manual_seed(0)
# Create dummy image embeds for IP adapter
cross_attention_dim = getattr(model.config, "joint_attention_dim", 32)
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}})
return inputs_dict
def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
return create_flux_ip_adapter_state_dict(model)
class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for Flux Transformer."""
class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for Flux Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for LoRA hotswap tests."""
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 32
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
"pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device),
"img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
"txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
state_dict = model.state_dict()
target_mod_shape = state_dict[f"{target_module}.weight"].shape
lora_state_dict = {
f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
}
# Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
config = LoraConfig(
r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
)
inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
set_peft_model_state_dict(model, lora_state_dict, adapter_name)
retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
assert len(retrieved_lora_state_dict) == len(lora_state_dict)
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for compilation tests."""
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
sequence_length = 24
embedding_dim = 32
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
return {
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
"pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device),
"img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
"txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
}
def prepare_dummy_input(self, height, width):
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
@property
def ckpt_path(self):
return "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
@property
def alternate_ckpt_paths(self):
return ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
@property
def pretrained_model_name_or_path(self):
return "black-forest-labs/FLUX.1-dev"
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Flux Transformer."""
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
"""Quanto quantization tests for Flux Transformer."""
@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-flux-transformer"
@property
def pretrained_model_kwargs(self):
return {}
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Flux Transformer."""
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
@property
def gguf_filename(self):
return "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real FLUX model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 4096, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pooled_projections": randn_tensor(
(1, 768), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
"img_ids": randn_tensor((4096, 3), generator=self.generator, device=torch_device, dtype=self.torch_dtype),
"txt_ids": randn_tensor((512, 3), generator=self.generator, device=torch_device, dtype=self.torch_dtype),
"guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
}
class TestFluxTransformerQuantoCompile(FluxTransformerTesterConfig, QuantoCompileTesterMixin):
"""Quanto + compile tests for Flux Transformer."""
class TestFluxTransformerTorchAoCompile(FluxTransformerTesterConfig, TorchAoCompileTesterMixin):
"""TorchAO + compile tests for Flux Transformer."""
class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTesterMixin):
@property
def gguf_filename(self):
return "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real FLUX model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 4096, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pooled_projections": randn_tensor(
(1, 768), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
"img_ids": randn_tensor((4096, 3), generator=self.generator, device=torch_device, dtype=self.torch_dtype),
"txt_ids": randn_tensor((512, 3), generator=self.generator, device=torch_device, dtype=self.torch_dtype),
"guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
}
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
"""ModelOpt quantization tests for Flux Transformer."""
class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
"""ModelOpt + compile tests for Flux Transformer."""
@pytest.mark.skip(reason="torch.compile is not supported by BitsAndBytes")
class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin):
"""BitsAndBytes + compile tests for Flux Transformer."""
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
"""FirstBlockCache tests for Flux Transformer."""
class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTesterMixin):
"""FasterCache tests for Flux Transformer."""
# Flux is guidance distilled, so we can test at model level without CFG batch handling
FASTER_CACHE_CONFIG = {
"spatial_attention_block_skip_range": 2,
"spatial_attention_timestep_skip_range": (-1, 901),
"tensor_format": "BCHW",
"is_guidance_distilled": True,
}
def prepare_dummy_input(self, height, width):
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)

View File

@@ -38,7 +38,6 @@ from diffusers.utils.import_utils import (
is_gguf_available,
is_kernels_available,
is_note_seq_available,
is_nvidia_modelopt_version,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
@@ -131,59 +130,6 @@ def torch_all_close(a, b, *args, **kwargs):
return True
def assert_tensors_close(
actual: "torch.Tensor",
expected: "torch.Tensor",
atol: float = 1e-5,
rtol: float = 1e-5,
msg: str = "",
) -> None:
"""
Assert that two tensors are close within tolerance.
Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected|
Provides concise, actionable error messages without dumping full tensors.
Args:
actual: The actual tensor from the computation.
expected: The expected tensor to compare against.
atol: Absolute tolerance.
rtol: Relative tolerance.
msg: Optional message prefix for the assertion error.
Raises:
AssertionError: If tensors have different shapes or values exceed tolerance.
Example:
>>> assert_tensors_close(output, expected_output, atol=1e-5, rtol=1e-5, msg="Forward pass")
"""
if not is_torch_available():
raise ValueError("PyTorch needs to be installed to use this function.")
if actual.shape != expected.shape:
raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}")
if not torch.allclose(actual, expected, atol=atol, rtol=rtol):
abs_diff = (actual - expected).abs()
max_diff = abs_diff.max().item()
flat_idx = abs_diff.argmax().item()
max_idx = tuple(torch.unravel_index(torch.tensor(flat_idx), actual.shape).tolist())
threshold = atol + rtol * expected.abs()
mismatched = (abs_diff > threshold).sum().item()
total = actual.numel()
raise AssertionError(
f"{msg}\n"
f"Tensors not close! Mismatched elements: {mismatched}/{total} ({100 * mismatched / total:.1f}%)\n"
f" Max diff: {max_diff:.6e} at index {max_idx}\n"
f" Actual: {actual.flatten()[flat_idx].item():.6e}\n"
f" Expected: {expected.flatten()[flat_idx].item():.6e}\n"
f" atol: {atol:.6e}, rtol: {rtol:.6e}"
)
def numpy_cosine_similarity_distance(a, b):
similarity = np.dot(a, b) / (norm(a) * norm(b))
distance = 1.0 - similarity.mean()
@@ -295,6 +241,7 @@ def parse_flag_from_env(key, default=False):
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)
def floats_tensor(shape, scale=1.0, rng=None, name=None):
@@ -335,155 +282,12 @@ def nightly(test_case):
def is_torch_compile(test_case):
"""
Decorator marking a test as a torch.compile test. These tests can be filtered using:
pytest -m "not compile" to skip
pytest -m compile to run only these tests
"""
return pytest.mark.compile(test_case)
Decorator marking a test that runs compile tests in the diffusers CI.
Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
def is_single_file(test_case):
"""
Decorator marking a test as a single file loading test. These tests can be filtered using:
pytest -m "not single_file" to skip
pytest -m single_file to run only these tests
"""
return pytest.mark.single_file(test_case)
def is_lora(test_case):
"""
Decorator marking a test as a LoRA test. These tests can be filtered using:
pytest -m "not lora" to skip
pytest -m lora to run only these tests
"""
return pytest.mark.lora(test_case)
def is_ip_adapter(test_case):
"""
Decorator marking a test as an IP Adapter test. These tests can be filtered using:
pytest -m "not ip_adapter" to skip
pytest -m ip_adapter to run only these tests
"""
return pytest.mark.ip_adapter(test_case)
def is_training(test_case):
"""
Decorator marking a test as a training test. These tests can be filtered using:
pytest -m "not training" to skip
pytest -m training to run only these tests
"""
return pytest.mark.training(test_case)
def is_attention(test_case):
"""
Decorator marking a test as an attention test. These tests can be filtered using:
pytest -m "not attention" to skip
pytest -m attention to run only these tests
"""
return pytest.mark.attention(test_case)
def is_memory(test_case):
"""
Decorator marking a test as a memory optimization test. These tests can be filtered using:
pytest -m "not memory" to skip
pytest -m memory to run only these tests
"""
return pytest.mark.memory(test_case)
def is_cpu_offload(test_case):
"""
Decorator marking a test as a CPU offload test. These tests can be filtered using:
pytest -m "not cpu_offload" to skip
pytest -m cpu_offload to run only these tests
"""
return pytest.mark.cpu_offload(test_case)
def is_group_offload(test_case):
"""
Decorator marking a test as a group offload test. These tests can be filtered using:
pytest -m "not group_offload" to skip
pytest -m group_offload to run only these tests
"""
return pytest.mark.group_offload(test_case)
def is_quantization(test_case):
"""
Decorator marking a test as a quantization test. These tests can be filtered using:
pytest -m "not quantization" to skip
pytest -m quantization to run only these tests
"""
return pytest.mark.quantization(test_case)
def is_bitsandbytes(test_case):
"""
Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using:
pytest -m "not bitsandbytes" to skip
pytest -m bitsandbytes to run only these tests
"""
return pytest.mark.bitsandbytes(test_case)
def is_quanto(test_case):
"""
Decorator marking a test as a Quanto quantization test. These tests can be filtered using:
pytest -m "not quanto" to skip
pytest -m quanto to run only these tests
"""
return pytest.mark.quanto(test_case)
def is_torchao(test_case):
"""
Decorator marking a test as a TorchAO quantization test. These tests can be filtered using:
pytest -m "not torchao" to skip
pytest -m torchao to run only these tests
"""
return pytest.mark.torchao(test_case)
def is_gguf(test_case):
"""
Decorator marking a test as a GGUF quantization test. These tests can be filtered using:
pytest -m "not gguf" to skip
pytest -m gguf to run only these tests
"""
return pytest.mark.gguf(test_case)
def is_modelopt(test_case):
"""
Decorator marking a test as a NVIDIA ModelOpt quantization test. These tests can be filtered using:
pytest -m "not modelopt" to skip
pytest -m modelopt to run only these tests
"""
return pytest.mark.modelopt(test_case)
def is_context_parallel(test_case):
"""
Decorator marking a test as a context parallel inference test. These tests can be filtered using:
pytest -m "not context_parallel" to skip
pytest -m context_parallel to run only these tests
"""
return pytest.mark.context_parallel(test_case)
def is_cache(test_case):
"""
Decorator marking a test as a cache test. These tests can be filtered using:
pytest -m "not cache" to skip
pytest -m cache to run only these tests
"""
return pytest.mark.cache(test_case)
return pytest.mark.skipif(not _run_compile_tests, reason="test is torch compile")(test_case)
def require_torch(test_case):
@@ -846,16 +650,6 @@ def require_kernels_version_greater_or_equal(kernels_version):
return decorator
def require_modelopt_version_greater_or_equal(modelopt_version):
def decorator(test_case):
return pytest.mark.skipif(
not is_nvidia_modelopt_version(">=", modelopt_version),
reason=f"Test requires modelopt with version greater than {modelopt_version}.",
)(test_case)
return decorator
def deprecate_after_peft_backend(test_case):
"""
Decorator marking a test that will be skipped after PEFT backend

View File

@@ -1,592 +0,0 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility script to generate test suites for diffusers model classes.
Usage:
python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_flux.py
This will analyze the model file and generate a test file with appropriate
test classes based on the model's mixins and attributes.
"""
import argparse
import ast
import sys
from pathlib import Path
MIXIN_TO_TESTER = {
"ModelMixin": "ModelTesterMixin",
"PeftAdapterMixin": "LoraTesterMixin",
}
ATTRIBUTE_TO_TESTER = {
"_cp_plan": "ContextParallelTesterMixin",
"_supports_gradient_checkpointing": "TrainingTesterMixin",
}
ALWAYS_INCLUDE_TESTERS = [
"ModelTesterMixin",
"MemoryTesterMixin",
"TorchCompileTesterMixin",
]
# Attention-related class names that indicate the model uses attention
ATTENTION_INDICATORS = {
"AttentionMixin",
"AttentionModuleMixin",
}
OPTIONAL_TESTERS = [
# Quantization testers
("BitsAndBytesTesterMixin", "bnb"),
("QuantoTesterMixin", "quanto"),
("TorchAoTesterMixin", "torchao"),
("GGUFTesterMixin", "gguf"),
("ModelOptTesterMixin", "modelopt"),
# Quantization compile testers
("BitsAndBytesCompileTesterMixin", "bnb_compile"),
("QuantoCompileTesterMixin", "quanto_compile"),
("TorchAoCompileTesterMixin", "torchao_compile"),
("GGUFCompileTesterMixin", "gguf_compile"),
("ModelOptCompileTesterMixin", "modelopt_compile"),
# Cache testers
("PyramidAttentionBroadcastTesterMixin", "pab_cache"),
("FirstBlockCacheTesterMixin", "fbc_cache"),
("FasterCacheTesterMixin", "faster_cache"),
# Other testers
("SingleFileTesterMixin", "single_file"),
("IPAdapterTesterMixin", "ip_adapter"),
]
class ModelAnalyzer(ast.NodeVisitor):
def __init__(self):
self.model_classes = []
self.current_class = None
self.imports = set()
def visit_Import(self, node: ast.Import):
for alias in node.names:
self.imports.add(alias.name.split(".")[-1])
self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom):
for alias in node.names:
self.imports.add(alias.name)
self.generic_visit(node)
def visit_ClassDef(self, node: ast.ClassDef):
base_names = []
for base in node.bases:
if isinstance(base, ast.Name):
base_names.append(base.id)
elif isinstance(base, ast.Attribute):
base_names.append(base.attr)
if "ModelMixin" in base_names:
class_info = {
"name": node.name,
"bases": base_names,
"attributes": {},
"has_forward": False,
"init_params": [],
}
for item in node.body:
if isinstance(item, ast.Assign):
for target in item.targets:
if isinstance(target, ast.Name):
attr_name = target.id
if attr_name.startswith("_"):
class_info["attributes"][attr_name] = self._get_value(item.value)
elif isinstance(item, ast.FunctionDef):
if item.name == "forward":
class_info["has_forward"] = True
class_info["forward_params"] = self._extract_func_params(item)
elif item.name == "__init__":
class_info["init_params"] = self._extract_func_params(item)
self.model_classes.append(class_info)
self.generic_visit(node)
def _extract_func_params(self, func_node: ast.FunctionDef) -> list[dict]:
params = []
args = func_node.args
num_defaults = len(args.defaults)
num_args = len(args.args)
first_default_idx = num_args - num_defaults
for i, arg in enumerate(args.args):
if arg.arg == "self":
continue
param_info = {"name": arg.arg, "type": None, "default": None}
if arg.annotation:
param_info["type"] = self._get_annotation_str(arg.annotation)
default_idx = i - first_default_idx
if default_idx >= 0 and default_idx < len(args.defaults):
param_info["default"] = self._get_value(args.defaults[default_idx])
params.append(param_info)
return params
def _get_annotation_str(self, node) -> str:
if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Constant):
return repr(node.value)
elif isinstance(node, ast.Subscript):
base = self._get_annotation_str(node.value)
if isinstance(node.slice, ast.Tuple):
args = ", ".join(self._get_annotation_str(el) for el in node.slice.elts)
else:
args = self._get_annotation_str(node.slice)
return f"{base}[{args}]"
elif isinstance(node, ast.Attribute):
return f"{self._get_annotation_str(node.value)}.{node.attr}"
elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
left = self._get_annotation_str(node.left)
right = self._get_annotation_str(node.right)
return f"{left} | {right}"
elif isinstance(node, ast.Tuple):
return ", ".join(self._get_annotation_str(el) for el in node.elts)
return "Any"
def _get_value(self, node):
if isinstance(node, ast.Constant):
return node.value
elif isinstance(node, ast.Name):
if node.id == "None":
return None
elif node.id == "True":
return True
elif node.id == "False":
return False
return node.id
elif isinstance(node, ast.List):
return [self._get_value(el) for el in node.elts]
elif isinstance(node, ast.Dict):
return {self._get_value(k): self._get_value(v) for k, v in zip(node.keys, node.values)}
return "<complex>"
def analyze_model_file(filepath: str) -> tuple[list[dict], set[str]]:
with open(filepath) as f:
source = f.read()
tree = ast.parse(source)
analyzer = ModelAnalyzer()
analyzer.visit(tree)
return analyzer.model_classes, analyzer.imports
def determine_testers(model_info: dict, include_optional: list[str], imports: set[str]) -> list[str]:
testers = list(ALWAYS_INCLUDE_TESTERS)
for base in model_info["bases"]:
if base in MIXIN_TO_TESTER:
tester = MIXIN_TO_TESTER[base]
if tester not in testers:
testers.append(tester)
for attr, tester in ATTRIBUTE_TO_TESTER.items():
if attr in model_info["attributes"]:
value = model_info["attributes"][attr]
if value is not None and value is not False:
if tester not in testers:
testers.append(tester)
if "_cp_plan" in model_info["attributes"] and model_info["attributes"]["_cp_plan"] is not None:
if "ContextParallelTesterMixin" not in testers:
testers.append("ContextParallelTesterMixin")
# Include AttentionTesterMixin if the model imports attention-related classes
if imports & ATTENTION_INDICATORS:
testers.append("AttentionTesterMixin")
for tester, flag in OPTIONAL_TESTERS:
if flag in include_optional:
if tester not in testers:
testers.append(tester)
return testers
def generate_config_class(model_info: dict, model_name: str) -> str:
class_name = f"{model_name}TesterConfig"
model_class = model_info["name"]
forward_params = model_info.get("forward_params", [])
init_params = model_info.get("init_params", [])
lines = [
f"class {class_name}:",
" @property",
" def model_class(self):",
f" return {model_class}",
"",
" @property",
" def pretrained_model_name_or_path(self):",
' return "" # TODO: Set Hub repository ID',
"",
" @property",
" def pretrained_model_kwargs(self):",
' return {"subfolder": "transformer"}',
"",
" @property",
" def generator(self):",
' return torch.Generator("cpu").manual_seed(0)',
"",
" def get_init_dict(self) -> dict[str, int | list[int]]:",
]
if init_params:
lines.append(" # __init__ parameters:")
for param in init_params:
type_str = f": {param['type']}" if param["type"] else ""
default_str = f" = {param['default']}" if param["default"] is not None else ""
lines.append(f" # {param['name']}{type_str}{default_str}")
lines.extend(
[
" return {}",
"",
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
]
)
if forward_params:
lines.append(" # forward() parameters:")
for param in forward_params:
type_str = f": {param['type']}" if param["type"] else ""
default_str = f" = {param['default']}" if param["default"] is not None else ""
lines.append(f" # {param['name']}{type_str}{default_str}")
lines.extend(
[
" # TODO: Fill in dummy inputs",
" return {}",
"",
" @property",
" def input_shape(self) -> tuple[int, ...]:",
" return (1, 1)",
"",
" @property",
" def output_shape(self) -> tuple[int, ...]:",
" return (1, 1)",
]
)
return "\n".join(lines)
def generate_test_class(model_name: str, config_class: str, tester: str) -> str:
tester_short = tester.replace("TesterMixin", "")
class_name = f"Test{model_name}{tester_short}"
lines = [f"class {class_name}({config_class}, {tester}):"]
if tester == "TorchCompileTesterMixin":
lines.extend(
[
" @property",
" def different_shapes_for_compilation(self):",
" return [(4, 4), (4, 8), (8, 8)]",
"",
" def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
" # TODO: Implement dynamic input generation",
" return {}",
]
)
elif tester == "IPAdapterTesterMixin":
lines.extend(
[
" @property",
" def ip_adapter_processor_cls(self):",
" return None # TODO: Set processor class",
"",
" def modify_inputs_for_ip_adapter(self, model, inputs_dict):",
" # TODO: Add IP adapter image embeds to inputs",
" return inputs_dict",
"",
" def create_ip_adapter_state_dict(self, model):",
" # TODO: Create IP adapter state dict",
" return {}",
]
)
elif tester == "SingleFileTesterMixin":
lines.extend(
[
" @property",
" def ckpt_path(self):",
' return "" # TODO: Set checkpoint path',
"",
" @property",
" def alternate_ckpt_paths(self):",
" return []",
"",
" @property",
" def pretrained_model_name_or_path(self):",
' return "" # TODO: Set Hub repository ID',
]
)
elif tester == "GGUFTesterMixin":
lines.extend(
[
" @property",
" def gguf_filename(self):",
' return "" # TODO: Set GGUF filename',
"",
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization tests",
" return {}",
]
)
elif tester in ["BitsAndBytesTesterMixin", "QuantoTesterMixin", "TorchAoTesterMixin", "ModelOptTesterMixin"]:
lines.extend(
[
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization tests",
" return {}",
]
)
elif tester in [
"BitsAndBytesCompileTesterMixin",
"QuantoCompileTesterMixin",
"TorchAoCompileTesterMixin",
"ModelOptCompileTesterMixin",
]:
lines.extend(
[
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization compile tests",
" return {}",
]
)
elif tester == "GGUFCompileTesterMixin":
lines.extend(
[
" @property",
" def gguf_filename(self):",
' return "" # TODO: Set GGUF filename',
"",
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
" # TODO: Override with larger inputs for quantization compile tests",
" return {}",
]
)
elif tester in [
"PyramidAttentionBroadcastTesterMixin",
"FirstBlockCacheTesterMixin",
"FasterCacheTesterMixin",
]:
lines.append(" pass")
elif tester == "LoraHotSwappingForModelTesterMixin":
lines.extend(
[
" @property",
" def different_shapes_for_compilation(self):",
" return [(4, 4), (4, 8), (8, 8)]",
"",
" def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
" # TODO: Implement dynamic input generation",
" return {}",
]
)
else:
lines.append(" pass")
return "\n".join(lines)
def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str], imports: set[str]) -> str:
model_name = model_info["name"].replace("2DModel", "").replace("3DModel", "").replace("Model", "")
testers = determine_testers(model_info, include_optional, imports)
tester_imports = sorted(set(testers) - {"LoraHotSwappingForModelTesterMixin"})
lines = [
"# coding=utf-8",
"# Copyright 2025 HuggingFace Inc.",
"#",
'# Licensed under the Apache License, Version 2.0 (the "License");',
"# you may not use this file except in compliance with the License.",
"# You may obtain a copy of the License at",
"#",
"# http://www.apache.org/licenses/LICENSE-2.0",
"#",
"# Unless required by applicable law or agreed to in writing, software",
'# distributed under the License is distributed on an "AS IS" BASIS,',
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
"# See the License for the specific language governing permissions and",
"# limitations under the License.",
"",
"import torch",
"",
f"from diffusers import {model_info['name']}",
"from diffusers.utils.torch_utils import randn_tensor",
"",
"from ...testing_utils import enable_full_determinism, torch_device",
]
if "LoraTesterMixin" in testers:
lines.append("from ..test_modeling_common import LoraHotSwappingForModelTesterMixin")
lines.extend(
[
"from ..testing_utils import (",
*[f" {tester}," for tester in sorted(tester_imports)],
")",
"",
"",
"enable_full_determinism()",
"",
"",
]
)
config_class = f"{model_name}TesterConfig"
lines.append(generate_config_class(model_info, model_name))
lines.append("")
lines.append("")
for tester in testers:
lines.append(generate_test_class(model_name, config_class, tester))
lines.append("")
lines.append("")
if "LoraTesterMixin" in testers:
lines.append(generate_test_class(model_name, config_class, "LoraHotSwappingForModelTesterMixin"))
lines.append("")
lines.append("")
return "\n".join(lines).rstrip() + "\n"
def get_test_output_path(model_filepath: str) -> str:
path = Path(model_filepath)
model_filename = path.stem
if "transformers" in path.parts:
return f"tests/models/transformers/test_models_{model_filename}.py"
elif "unets" in path.parts:
return f"tests/models/unets/test_models_{model_filename}.py"
elif "autoencoders" in path.parts:
return f"tests/models/autoencoders/test_models_{model_filename}.py"
else:
return f"tests/models/test_models_{model_filename}.py"
def main():
parser = argparse.ArgumentParser(description="Generate test suite for a diffusers model class")
parser.add_argument(
"model_filepath",
type=str,
help="Path to the model file (e.g., src/diffusers/models/transformers/transformer_flux.py)",
)
parser.add_argument(
"--output", "-o", type=str, default=None, help="Output file path (default: auto-generated based on model path)"
)
parser.add_argument(
"--include",
"-i",
type=str,
nargs="*",
default=[],
choices=[
"bnb",
"quanto",
"torchao",
"gguf",
"modelopt",
"bnb_compile",
"quanto_compile",
"torchao_compile",
"gguf_compile",
"modelopt_compile",
"pab_cache",
"fbc_cache",
"faster_cache",
"single_file",
"ip_adapter",
"all",
],
help="Optional testers to include",
)
parser.add_argument(
"--class-name",
"-c",
type=str,
default=None,
help="Specific model class to generate tests for (default: first model class found)",
)
parser.add_argument("--dry-run", action="store_true", help="Print generated code without writing to file")
args = parser.parse_args()
if not Path(args.model_filepath).exists():
print(f"Error: File not found: {args.model_filepath}", file=sys.stderr)
sys.exit(1)
model_classes, imports = analyze_model_file(args.model_filepath)
if not model_classes:
print(f"Error: No model classes found in {args.model_filepath}", file=sys.stderr)
sys.exit(1)
if args.class_name:
model_info = next((m for m in model_classes if m["name"] == args.class_name), None)
if not model_info:
available = [m["name"] for m in model_classes]
print(f"Error: Class '{args.class_name}' not found. Available: {available}", file=sys.stderr)
sys.exit(1)
else:
model_info = model_classes[0]
if len(model_classes) > 1:
print(f"Multiple model classes found, using: {model_info['name']}", file=sys.stderr)
print("Use --class-name to specify a different class", file=sys.stderr)
include_optional = args.include
if "all" in include_optional:
include_optional = [flag for _, flag in OPTIONAL_TESTERS]
generated_code = generate_test_file(model_info, args.model_filepath, include_optional, imports)
if args.dry_run:
print(generated_code)
else:
output_path = args.output or get_test_output_path(args.model_filepath)
output_dir = Path(output_path).parent
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
f.write(generated_code)
print(f"Generated test file: {output_path}")
print(f"Model class: {model_info['name']}")
print(f"Detected attributes: {list(model_info['attributes'].keys())}")
if __name__ == "__main__":
main()