mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-03 01:15:10 +08:00
Compare commits
7 Commits
main
...
modular-wa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7f97055d5f | ||
|
|
85d8d244a1 | ||
|
|
7d272bed80 | ||
|
|
e95c5a2609 | ||
|
|
80275cf18b | ||
|
|
2b74061a11 | ||
|
|
23fb285912 |
@@ -343,34 +343,6 @@ We ran a benchmark with Ulysess, Ring, and Unified Attention with [this script](
|
||||
|
||||
From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by unified attention.
|
||||
|
||||
|
||||
### Ulysses Anything Attention
|
||||
|
||||
The default Ulysses Attention mechanism requires that the sequence length of hidden states must be divisible by the number of devices. This imposes significant limitations on the practical application of Ulysses Attention. [Ulysses Anything Attention](https://github.com/huggingface/diffusers/pull/12996) is a variant of Ulysses Attention that supports arbitrary sequence lengths and arbitrary numbers of attention heads, thereby enhancing the versatility of Ulysses Attention in practical use.
|
||||
|
||||
[`ContextParallelConfig`] supports Ulysses Anything Attention by specifying both `ulysses_degree` and `ulysses_anything`. Please note that Ulysses Anything Attention is not currently supported by Unified Attention. Pass the [`ContextParallelConfig`] with both `ulysses_degree` set to bigger than 1 and `ulysses_anything=True` to [`~ModelMixin.enable_parallelism`].
|
||||
|
||||
```py
|
||||
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ulysses_anything=True))
|
||||
```
|
||||
|
||||
> [!TIP] To avoid multiple forced CUDA sync caused by H2D and D2H transfers, please add the **gloo** backend in `init_process_group`. This will significantly reduce communication latency.
|
||||
|
||||
We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention and Ulysses Anything Attention with [this script](https://github.com/huggingface/diffusers/pull/12996#issuecomment-3797695999) on a node of 4 L20 GPUs. The results are summarized as follows:
|
||||
|
||||
| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW)|
|
||||
|--------------------|------------------|-------------|------------------|------------|
|
||||
| ulysses | 281.07 | 3.56 | 37.11 | 1024x1024 |
|
||||
| ring | 351.34 | 2.85 | 37.01 | 1024x1024 |
|
||||
| unified_balanced | 324.37 | 3.08 | 37.16 | 1024x1024 |
|
||||
| ulysses_anything | 280.94 | 3.56 | 37.11 | 1024x1024 |
|
||||
| ulysses | failed | failed | failed | 1008x1008 |
|
||||
| ring | failed | failed | failed | 1008x1008 |
|
||||
| unified_balanced | failed | failed | failed | 1008x1008 |
|
||||
| ulysses_anything | 278.40 | 3.59 | 36.99 | 1008x1008 |
|
||||
|
||||
From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention.
|
||||
|
||||
### parallel_config
|
||||
|
||||
Pass `parallel_config` during model initialization to enable context parallelism.
|
||||
|
||||
@@ -415,6 +415,7 @@ else:
|
||||
"Flux2AutoBlocks",
|
||||
"Flux2KleinAutoBlocks",
|
||||
"Flux2KleinBaseAutoBlocks",
|
||||
"Flux2KleinBaseModularPipeline",
|
||||
"Flux2KleinModularPipeline",
|
||||
"Flux2ModularPipeline",
|
||||
"FluxAutoBlocks",
|
||||
@@ -431,8 +432,13 @@ else:
|
||||
"QwenImageModularPipeline",
|
||||
"StableDiffusionXLAutoBlocks",
|
||||
"StableDiffusionXLModularPipeline",
|
||||
"Wan22AutoBlocks",
|
||||
"WanAutoBlocks",
|
||||
"Wan22Blocks",
|
||||
"Wan22Image2VideoBlocks",
|
||||
"Wan22Image2VideoModularPipeline",
|
||||
"Wan22ModularPipeline",
|
||||
"WanBlocks",
|
||||
"WanImage2VideoAutoBlocks",
|
||||
"WanImage2VideoModularPipeline",
|
||||
"WanModularPipeline",
|
||||
"ZImageAutoBlocks",
|
||||
"ZImageModularPipeline",
|
||||
@@ -1151,6 +1157,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Flux2AutoBlocks,
|
||||
Flux2KleinAutoBlocks,
|
||||
Flux2KleinBaseAutoBlocks,
|
||||
Flux2KleinBaseModularPipeline,
|
||||
Flux2KleinModularPipeline,
|
||||
Flux2ModularPipeline,
|
||||
FluxAutoBlocks,
|
||||
@@ -1167,8 +1174,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageModularPipeline,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
Wan22AutoBlocks,
|
||||
WanAutoBlocks,
|
||||
Wan22Blocks,
|
||||
Wan22Image2VideoBlocks,
|
||||
Wan22Image2VideoModularPipeline,
|
||||
Wan22ModularPipeline,
|
||||
WanBlocks,
|
||||
WanImage2VideoAutoBlocks,
|
||||
WanImage2VideoModularPipeline,
|
||||
WanModularPipeline,
|
||||
ZImageAutoBlocks,
|
||||
ZImageModularPipeline,
|
||||
|
||||
@@ -11,14 +11,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import functools
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple, Type, Union
|
||||
from typing import Dict, List, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
if torch.distributed.is_available():
|
||||
@@ -29,10 +27,9 @@ from ..models._modeling_parallel import (
|
||||
ContextParallelInput,
|
||||
ContextParallelModelPlan,
|
||||
ContextParallelOutput,
|
||||
gather_size_by_comm,
|
||||
)
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
@@ -211,10 +208,6 @@ class ContextParallelSplitHook(ModelHook):
|
||||
)
|
||||
return x
|
||||
else:
|
||||
if self.parallel_config.ulysses_anything:
|
||||
return PartitionAnythingSharder.shard_anything(
|
||||
x, cp_input.split_dim, self.parallel_config._flattened_mesh
|
||||
)
|
||||
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
|
||||
|
||||
|
||||
@@ -240,14 +233,7 @@ class ContextParallelGatherHook(ModelHook):
|
||||
for i, cpm in enumerate(self.metadata):
|
||||
if cpm is None:
|
||||
continue
|
||||
if self.parallel_config.ulysses_anything:
|
||||
output[i] = PartitionAnythingSharder.unshard_anything(
|
||||
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
|
||||
)
|
||||
else:
|
||||
output[i] = EquipartitionSharder.unshard(
|
||||
output[i], cpm.gather_dim, self.parallel_config._flattened_mesh
|
||||
)
|
||||
output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
|
||||
|
||||
return output[0] if is_tensor else tuple(output)
|
||||
|
||||
@@ -288,73 +274,6 @@ class EquipartitionSharder:
|
||||
return tensor
|
||||
|
||||
|
||||
class AllGatherAnythingFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh):
|
||||
ctx.dim = dim
|
||||
ctx.group = group
|
||||
ctx.world_size = dist.get_world_size(group)
|
||||
ctx.rank = dist.get_rank(group)
|
||||
gathered_tensor = _all_gather_anything(tensor, dim, group)
|
||||
return gathered_tensor
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
|
||||
# function may return fewer than the specified number of chunks!
|
||||
grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim)
|
||||
return grad_splits[ctx.rank], None, None
|
||||
|
||||
|
||||
class PartitionAnythingSharder:
|
||||
@classmethod
|
||||
def shard_anything(
|
||||
cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
|
||||
) -> torch.Tensor:
|
||||
assert tensor.size()[dim] >= mesh.size(), (
|
||||
f"Cannot shard tensor of size {tensor.size()} along dim {dim} across mesh of size {mesh.size()}."
|
||||
)
|
||||
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
|
||||
# function may return fewer than the specified number of chunks!
|
||||
return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())]
|
||||
|
||||
@classmethod
|
||||
def unshard_anything(
|
||||
cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh
|
||||
) -> torch.Tensor:
|
||||
tensor = tensor.contiguous()
|
||||
tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group())
|
||||
return tensor
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=64)
|
||||
def _fill_gather_shapes(shape: Tuple[int], gather_dims: Tuple[int], dim: int, world_size: int) -> List[List[int]]:
|
||||
gather_shapes = []
|
||||
for i in range(world_size):
|
||||
rank_shape = list(copy.deepcopy(shape))
|
||||
rank_shape[dim] = gather_dims[i]
|
||||
gather_shapes.append(rank_shape)
|
||||
return gather_shapes
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
def _all_gather_anything(tensor: torch.Tensor, dim: int, group: dist.device_mesh.DeviceMesh) -> torch.Tensor:
|
||||
world_size = dist.get_world_size(group=group)
|
||||
|
||||
tensor = tensor.contiguous()
|
||||
shape = tensor.shape
|
||||
rank_dim = shape[dim]
|
||||
gather_dims = gather_size_by_comm(rank_dim, group)
|
||||
|
||||
gather_shapes = _fill_gather_shapes(tuple(shape), tuple(gather_dims), dim, world_size)
|
||||
|
||||
gathered_tensors = [torch.empty(shape, device=tensor.device, dtype=tensor.dtype) for shape in gather_shapes]
|
||||
|
||||
dist.all_gather(gathered_tensors, tensor, group=group)
|
||||
gathered_tensor = torch.cat(gathered_tensors, dim=dim)
|
||||
return gathered_tensor
|
||||
|
||||
|
||||
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
|
||||
if name.count("*") > 1:
|
||||
raise ValueError("Wildcard '*' can only be used once in the name")
|
||||
|
||||
@@ -2321,14 +2321,8 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
|
||||
prefix = "diffusion_model."
|
||||
original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
|
||||
|
||||
num_double_layers = 0
|
||||
num_single_layers = 0
|
||||
for key in original_state_dict.keys():
|
||||
if key.startswith("single_blocks."):
|
||||
num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1)
|
||||
elif key.startswith("double_blocks."):
|
||||
num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1)
|
||||
|
||||
num_double_layers = 8
|
||||
num_single_layers = 48
|
||||
lora_keys = ("lora_A", "lora_B")
|
||||
attn_types = ("img_attn", "txt_attn")
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from ..utils import get_logger
|
||||
|
||||
@@ -68,9 +67,6 @@ class ContextParallelConfig:
|
||||
convert_to_fp32: bool = True
|
||||
# TODO: support alltoall
|
||||
rotate_method: Literal["allgather", "alltoall"] = "allgather"
|
||||
# Whether to enable ulysses anything attention to support
|
||||
# any sequence lengths and any head numbers.
|
||||
ulysses_anything: bool = False
|
||||
|
||||
_rank: int = None
|
||||
_world_size: int = None
|
||||
@@ -98,11 +94,6 @@ class ContextParallelConfig:
|
||||
raise NotImplementedError(
|
||||
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
|
||||
)
|
||||
if self.ulysses_anything:
|
||||
if self.ulysses_degree == 1:
|
||||
raise ValueError("ulysses_degree must be greater than 1 for ulysses_anything to be enabled.")
|
||||
if self.ring_degree > 1:
|
||||
raise ValueError("ulysses_anything cannot be enabled when ring_degree > 1.")
|
||||
|
||||
@property
|
||||
def mesh_shape(self) -> Tuple[int, int]:
|
||||
@@ -266,39 +257,3 @@ ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextPara
|
||||
#
|
||||
# ContextParallelOutput:
|
||||
# specifies how to gather the input tensor in the post-forward hook in the layer it is attached to
|
||||
|
||||
|
||||
# Below are utility functions for distributed communication in context parallelism.
|
||||
def gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]:
|
||||
r"""Gather the local size from all ranks.
|
||||
size: int, local size return: List[int], list of size from all ranks
|
||||
"""
|
||||
# NOTE(Serving/CP Safety):
|
||||
# Do NOT cache this collective result.
|
||||
#
|
||||
# In "Ulysses Anything" mode, `size` (e.g. per-rank local seq_len / S_LOCAL)
|
||||
# may legitimately differ across ranks. If we cache based on the *local* `size`,
|
||||
# different ranks can have different cache hit/miss patterns across time.
|
||||
#
|
||||
# That can lead to a catastrophic distributed hang:
|
||||
# - some ranks hit cache and *skip* dist.all_gather()
|
||||
# - other ranks miss cache and *enter* dist.all_gather()
|
||||
# This mismatched collective participation will stall the process group and
|
||||
# eventually trigger NCCL watchdog timeouts (often surfacing later as ALLTOALL
|
||||
# timeouts in Ulysses attention).
|
||||
world_size = dist.get_world_size(group=group)
|
||||
# HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead
|
||||
comm_backends = str(dist.get_backend(group=group))
|
||||
# NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl")
|
||||
gather_device = "cpu" if "cpu" in comm_backends else torch.accelerator.current_accelerator()
|
||||
gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)]
|
||||
dist.all_gather(
|
||||
gathered_sizes,
|
||||
torch.tensor([size], device=gather_device, dtype=torch.int64),
|
||||
group=group,
|
||||
)
|
||||
|
||||
gathered_sizes = [s[0].item() for s in gathered_sizes]
|
||||
# NOTE: DON'T use tolist here due to graph break - Explanation:
|
||||
# Backend compiler `inductor` failed with aten._local_scalar_dense.default
|
||||
return gathered_sizes
|
||||
|
||||
@@ -21,8 +21,6 @@ from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
if torch.distributed.is_available():
|
||||
@@ -46,8 +44,6 @@ from ..utils import (
|
||||
is_xformers_version,
|
||||
)
|
||||
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from ._modeling_parallel import gather_size_by_comm
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -1283,154 +1279,6 @@ class SeqAllToAllDim(torch.autograd.Function):
|
||||
return (None, grad_input, None, None)
|
||||
|
||||
|
||||
# Below are helper functions to handle abritrary head num and abritrary sequence length for Ulysses Anything Attention.
|
||||
def _maybe_pad_qkv_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]:
|
||||
r"""Maybe pad the head dimension to be divisible by world_size.
|
||||
x: torch.Tensor, shape (B, S_LOCAL, H, D) H: int, original global head num return: Tuple[torch.Tensor, int], padded
|
||||
tensor (B, S_LOCAL, H + H_PAD, D) and H_PAD
|
||||
"""
|
||||
world_size = dist.get_world_size(group=group)
|
||||
H_PAD = 0
|
||||
if H % world_size != 0:
|
||||
H_PAD = world_size - (H % world_size)
|
||||
NEW_H_LOCAL = (H + H_PAD) // world_size
|
||||
# e.g., Allow: H=30, world_size=8 -> NEW_H_LOCAL=4, H_PAD=2.
|
||||
# NOT ALLOW: H=30, world_size=16 -> NEW_H_LOCAL=2, H_PAD=14.
|
||||
assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
|
||||
x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
|
||||
return x, H_PAD
|
||||
|
||||
|
||||
def _maybe_unpad_qkv_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor:
|
||||
r"""Maybe unpad the head dimension.
|
||||
x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor,
|
||||
unpadded tensor (B, S_GLOBAL, H_LOCAL, D)
|
||||
"""
|
||||
rank = dist.get_rank(group=group)
|
||||
world_size = dist.get_world_size(group=group)
|
||||
# Only the last rank may have padding
|
||||
if H_PAD > 0 and rank == world_size - 1:
|
||||
x = x[:, :, :-H_PAD, :]
|
||||
return x.contiguous()
|
||||
|
||||
|
||||
def _maybe_pad_o_head(x: torch.Tensor, H: int, group: dist.ProcessGroup) -> Tuple[torch.Tensor, int]:
|
||||
r"""Maybe pad the head dimension to be divisible by world_size.
|
||||
x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) H: int, original global head num return: Tuple[torch.Tensor, int],
|
||||
padded tensor (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD
|
||||
"""
|
||||
if H is None:
|
||||
return x, 0
|
||||
|
||||
rank = dist.get_rank(group=group)
|
||||
world_size = dist.get_world_size(group=group)
|
||||
H_PAD = 0
|
||||
# Only the last rank may need padding
|
||||
if H % world_size != 0:
|
||||
# We need to broadcast H_PAD to all ranks to keep consistency
|
||||
# in unpadding step later for all ranks.
|
||||
H_PAD = world_size - (H % world_size)
|
||||
NEW_H_LOCAL = (H + H_PAD) // world_size
|
||||
assert H_PAD < NEW_H_LOCAL, f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
|
||||
if rank == world_size - 1:
|
||||
x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
|
||||
return x, H_PAD
|
||||
|
||||
|
||||
def _maybe_unpad_o_head(x: torch.Tensor, H_PAD: int, group: dist.ProcessGroup) -> torch.Tensor:
|
||||
r"""Maybe unpad the head dimension.
|
||||
x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL + H_PAD, D) H_PAD: int, head padding num return: torch.Tensor,
|
||||
unpadded tensor (B, S_LOCAL, H_GLOBAL, D)
|
||||
"""
|
||||
if H_PAD > 0:
|
||||
x = x[:, :, :-H_PAD, :]
|
||||
return x.contiguous()
|
||||
|
||||
|
||||
def ulysses_anything_metadata(query: torch.Tensor, **kwargs) -> dict:
|
||||
# query: (B, S_LOCAL, H_GLOBAL, D)
|
||||
assert len(query.shape) == 4, "Query tensor must be 4-dimensional of shape (B, S_LOCAL, H_GLOBAL, D)"
|
||||
extra_kwargs = {}
|
||||
extra_kwargs["NUM_QO_HEAD"] = query.shape[2]
|
||||
extra_kwargs["Q_S_LOCAL"] = query.shape[1]
|
||||
# Add other kwargs if needed in future
|
||||
return extra_kwargs
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
def all_to_all_single_any_qkv_async(
|
||||
x: torch.Tensor, group: dist.ProcessGroup, **kwargs
|
||||
) -> Callable[..., torch.Tensor]:
|
||||
r"""
|
||||
x: torch.Tensor, shape (B, S_LOCAL, H, D) return: Callable that returns (B, S_GLOBAL, H_LOCAL, D)
|
||||
"""
|
||||
world_size = dist.get_world_size(group=group)
|
||||
B, S_LOCAL, H, D = x.shape
|
||||
x, H_PAD = _maybe_pad_qkv_head(x, H, group)
|
||||
H_LOCAL = (H + H_PAD) // world_size
|
||||
# (world_size, S_LOCAL, B, H_LOCAL, D)
|
||||
x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
|
||||
|
||||
input_split_sizes = [S_LOCAL] * world_size
|
||||
# S_LOCAL maybe not equal for all ranks in dynamic shape case,
|
||||
# since we don't know the actual shape before this timing, thus,
|
||||
# we have to use all gather to collect the S_LOCAL first.
|
||||
output_split_sizes = gather_size_by_comm(S_LOCAL, group)
|
||||
x = x.flatten(0, 1) # (world_size * S_LOCAL, B, H_LOCAL, D)
|
||||
x = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group)
|
||||
|
||||
def wait() -> torch.Tensor:
|
||||
nonlocal x, H_PAD
|
||||
x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D)
|
||||
# (S_GLOBAL, B, H_LOCAL, D)
|
||||
# -> (B, S_GLOBAL, H_LOCAL, D)
|
||||
x = x.permute(1, 0, 2, 3).contiguous()
|
||||
x = _maybe_unpad_qkv_head(x, H_PAD, group)
|
||||
return x
|
||||
|
||||
return wait
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
def all_to_all_single_any_o_async(x: torch.Tensor, group: dist.ProcessGroup, **kwargs) -> Callable[..., torch.Tensor]:
|
||||
r"""
|
||||
x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D) return: Callable that returns (B, S_LOCAL, H_GLOBAL, D)
|
||||
"""
|
||||
# Assume H is provided in kwargs, since we can't infer H from x's shape.
|
||||
# The padding logic needs H to determine if padding is necessary.
|
||||
H = kwargs.get("NUM_QO_HEAD", None)
|
||||
world_size = dist.get_world_size(group=group)
|
||||
|
||||
x, H_PAD = _maybe_pad_o_head(x, H, group)
|
||||
shape = x.shape # (B, S_GLOBAL, H_LOCAL, D)
|
||||
(B, S_GLOBAL, H_LOCAL, D) = shape
|
||||
|
||||
# input_split: e.g, S_GLOBAL=9 input splits across ranks [[5,4], [5,4],..]
|
||||
# output_split: e.g, S_GLOBAL=9 output splits across ranks [[5,5], [4,4],..]
|
||||
|
||||
# WARN: In some cases, e.g, joint attn in Qwen-Image, the S_LOCAL can not infer
|
||||
# from tensor split due to: if c = torch.cat((a, b)), world_size=4, then,
|
||||
# c.tensor_split(4)[0].shape[1] may != to (a.tensor_split(4)[0].shape[1] +
|
||||
# b.tensor_split(4)[0].shape[1])
|
||||
|
||||
S_LOCAL = kwargs.get("Q_S_LOCAL")
|
||||
input_split_sizes = gather_size_by_comm(S_LOCAL, group)
|
||||
x = x.permute(1, 0, 2, 3).contiguous() # (S_GLOBAL, B, H_LOCAL, D)
|
||||
output_split_sizes = [S_LOCAL] * world_size
|
||||
x = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group)
|
||||
|
||||
def wait() -> torch.Tensor:
|
||||
nonlocal x, H_PAD
|
||||
x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D)
|
||||
x = x.reshape(world_size, S_LOCAL, B, H_LOCAL, D)
|
||||
x = x.permute(2, 1, 0, 3, 4).contiguous()
|
||||
x = x.reshape(B, S_LOCAL, world_size * H_LOCAL, D)
|
||||
x = _maybe_unpad_o_head(x, H_PAD, group)
|
||||
return x
|
||||
|
||||
return wait
|
||||
|
||||
|
||||
class TemplatedRingAttention(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
@@ -1653,82 +1501,6 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
|
||||
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class TemplatedUlyssesAnythingAttention(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor],
|
||||
dropout_p: float,
|
||||
is_causal: bool,
|
||||
scale: Optional[float],
|
||||
enable_gqa: bool,
|
||||
return_lse: bool,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
**kwargs,
|
||||
):
|
||||
ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
|
||||
group = ulysses_mesh.get_group()
|
||||
|
||||
ctx.forward_op = forward_op
|
||||
ctx.backward_op = backward_op
|
||||
ctx._parallel_config = _parallel_config
|
||||
|
||||
metadata = ulysses_anything_metadata(query)
|
||||
query_wait = all_to_all_single_any_qkv_async(query, group, **metadata)
|
||||
key_wait = all_to_all_single_any_qkv_async(key, group, **metadata)
|
||||
value_wait = all_to_all_single_any_qkv_async(value, group, **metadata)
|
||||
|
||||
query = query_wait() # type: torch.Tensor
|
||||
key = key_wait() # type: torch.Tensor
|
||||
value = value_wait() # type: torch.Tensor
|
||||
|
||||
out = forward_op(
|
||||
ctx,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
_save_ctx=False, # ulysses anything only support forward pass now.
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
|
||||
# out: (B, S_Q_GLOBAL, H_LOCAL, D) -> (B, S_Q_LOCAL, H_GLOBAL, D)
|
||||
out_wait = all_to_all_single_any_o_async(out, group, **metadata)
|
||||
|
||||
if return_lse:
|
||||
# lse: (B, S_Q_GLOBAL, H_LOCAL)
|
||||
lse = lse.unsqueeze(-1) # (B, S_Q_GLOBAL, H_LOCAL, D=1)
|
||||
lse_wait = all_to_all_single_any_o_async(lse, group, **metadata)
|
||||
out = out_wait() # type: torch.Tensor
|
||||
lse = lse_wait() # type: torch.Tensor
|
||||
lse = lse.squeeze(-1).contiguous() # (B, S_Q_LOCAL, H_GLOBAL)
|
||||
else:
|
||||
out = out_wait() # type: torch.Tensor
|
||||
lse = None
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
@staticmethod
|
||||
def backward(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
):
|
||||
raise NotImplementedError("Backward pass for Ulysses Anything Attention in diffusers is not implemented yet.")
|
||||
|
||||
|
||||
def _templated_unified_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
@@ -1846,37 +1618,20 @@ def _templated_context_parallel_attention(
|
||||
_parallel_config,
|
||||
)
|
||||
elif _parallel_config.context_parallel_config.ulysses_degree > 1:
|
||||
if _parallel_config.context_parallel_config.ulysses_anything:
|
||||
# For Any sequence lengths and Any head num support
|
||||
return TemplatedUlyssesAnythingAttention.apply(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config,
|
||||
)
|
||||
else:
|
||||
return TemplatedUlyssesAttention.apply(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config,
|
||||
)
|
||||
return TemplatedUlyssesAttention.apply(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Reaching this branch of code is unexpected. Please report a bug.")
|
||||
|
||||
|
||||
@@ -45,7 +45,16 @@ else:
|
||||
"InsertableDict",
|
||||
]
|
||||
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
|
||||
_import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"]
|
||||
_import_structure["wan"] = [
|
||||
"WanBlocks",
|
||||
"Wan22Blocks",
|
||||
"WanImage2VideoAutoBlocks",
|
||||
"Wan22Image2VideoBlocks",
|
||||
"WanModularPipeline",
|
||||
"Wan22ModularPipeline",
|
||||
"WanImage2VideoModularPipeline",
|
||||
"Wan22Image2VideoModularPipeline",
|
||||
]
|
||||
_import_structure["flux"] = [
|
||||
"FluxAutoBlocks",
|
||||
"FluxModularPipeline",
|
||||
@@ -58,6 +67,7 @@ else:
|
||||
"Flux2KleinBaseAutoBlocks",
|
||||
"Flux2ModularPipeline",
|
||||
"Flux2KleinModularPipeline",
|
||||
"Flux2KleinBaseModularPipeline",
|
||||
]
|
||||
_import_structure["qwenimage"] = [
|
||||
"QwenImageAutoBlocks",
|
||||
@@ -88,6 +98,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Flux2AutoBlocks,
|
||||
Flux2KleinAutoBlocks,
|
||||
Flux2KleinBaseAutoBlocks,
|
||||
Flux2KleinBaseModularPipeline,
|
||||
Flux2KleinModularPipeline,
|
||||
Flux2ModularPipeline,
|
||||
)
|
||||
@@ -112,7 +123,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageModularPipeline,
|
||||
)
|
||||
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
|
||||
from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
|
||||
from .wan import (
|
||||
Wan22Blocks,
|
||||
Wan22Image2VideoBlocks,
|
||||
Wan22Image2VideoModularPipeline,
|
||||
Wan22ModularPipeline,
|
||||
WanBlocks,
|
||||
WanImage2VideoAutoBlocks,
|
||||
WanImage2VideoModularPipeline,
|
||||
WanModularPipeline,
|
||||
)
|
||||
from .z_image import ZImageAutoBlocks, ZImageModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -55,7 +55,11 @@ else:
|
||||
"Flux2VaeEncoderSequentialStep",
|
||||
]
|
||||
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
|
||||
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline", "Flux2KleinModularPipeline"]
|
||||
_import_structure["modular_pipeline"] = [
|
||||
"Flux2ModularPipeline",
|
||||
"Flux2KleinModularPipeline",
|
||||
"Flux2KleinBaseModularPipeline",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -101,7 +105,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Flux2KleinAutoBlocks,
|
||||
Flux2KleinBaseAutoBlocks,
|
||||
)
|
||||
from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline
|
||||
from .modular_pipeline import Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ...loaders import Flux2LoraLoaderMixin
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipeline
|
||||
@@ -59,46 +57,35 @@ class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
|
||||
return num_channels_latents
|
||||
|
||||
|
||||
class Flux2KleinModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
|
||||
class Flux2KleinModularPipeline(Flux2ModularPipeline):
|
||||
"""
|
||||
A ModularPipeline for Flux2-Klein.
|
||||
A ModularPipeline for Flux2-Klein (distilled model).
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "Flux2KleinBaseAutoBlocks"
|
||||
|
||||
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
if config_dict is not None and "is_distilled" in config_dict and config_dict["is_distilled"]:
|
||||
return "Flux2KleinAutoBlocks"
|
||||
else:
|
||||
return "Flux2KleinBaseAutoBlocks"
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
@property
|
||||
def default_width(self):
|
||||
return self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
@property
|
||||
def default_sample_size(self):
|
||||
return 128
|
||||
|
||||
@property
|
||||
def vae_scale_factor(self):
|
||||
vae_scale_factor = 8
|
||||
if getattr(self, "vae", None) is not None:
|
||||
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
return vae_scale_factor
|
||||
|
||||
@property
|
||||
def num_channels_latents(self):
|
||||
num_channels_latents = 32
|
||||
if getattr(self, "transformer", None):
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
return num_channels_latents
|
||||
default_blocks_name = "Flux2KleinAutoBlocks"
|
||||
|
||||
@property
|
||||
def requires_unconditional_embeds(self):
|
||||
if hasattr(self.config, "is_distilled") and self.config.is_distilled:
|
||||
return False
|
||||
|
||||
requires_unconditional_embeds = False
|
||||
if hasattr(self, "guider") and self.guider is not None:
|
||||
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
|
||||
|
||||
return requires_unconditional_embeds
|
||||
|
||||
|
||||
class Flux2KleinBaseModularPipeline(Flux2ModularPipeline):
|
||||
"""
|
||||
A ModularPipeline for Flux2-Klein (base model).
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "Flux2KleinBaseAutoBlocks"
|
||||
|
||||
@property
|
||||
def requires_unconditional_embeds(self):
|
||||
|
||||
@@ -52,19 +52,61 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# map regular pipeline to modular pipeline class name
|
||||
|
||||
|
||||
def _create_default_map_fn(pipeline_class_name: str):
|
||||
"""Create a mapping function that always returns the same pipeline class."""
|
||||
|
||||
def _map_fn(config_dict=None):
|
||||
return pipeline_class_name
|
||||
|
||||
return _map_fn
|
||||
|
||||
|
||||
def _flux2_klein_map_fn(config_dict=None):
|
||||
if config_dict is None:
|
||||
return "Flux2KleinModularPipeline"
|
||||
|
||||
if "is_distilled" in config_dict and config_dict["is_distilled"]:
|
||||
return "Flux2KleinModularPipeline"
|
||||
else:
|
||||
return "Flux2KleinBaseModularPipeline"
|
||||
|
||||
|
||||
def _wan_map_fn(config_dict=None):
|
||||
if config_dict is None:
|
||||
return "WanModularPipeline"
|
||||
|
||||
if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
|
||||
return "Wan22ModularPipeline"
|
||||
else:
|
||||
return "WanModularPipeline"
|
||||
|
||||
|
||||
def _wan_i2v_map_fn(config_dict=None):
|
||||
if config_dict is None:
|
||||
return "WanImage2VideoModularPipeline"
|
||||
|
||||
if "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
|
||||
return "Wan22Image2VideoModularPipeline"
|
||||
else:
|
||||
return "WanImage2VideoModularPipeline"
|
||||
|
||||
|
||||
MODULAR_PIPELINE_MAPPING = OrderedDict(
|
||||
[
|
||||
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
|
||||
("wan", "WanModularPipeline"),
|
||||
("flux", "FluxModularPipeline"),
|
||||
("flux-kontext", "FluxKontextModularPipeline"),
|
||||
("flux2", "Flux2ModularPipeline"),
|
||||
("flux2-klein", "Flux2KleinModularPipeline"),
|
||||
("qwenimage", "QwenImageModularPipeline"),
|
||||
("qwenimage-edit", "QwenImageEditModularPipeline"),
|
||||
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
|
||||
("qwenimage-layered", "QwenImageLayeredModularPipeline"),
|
||||
("z-image", "ZImageModularPipeline"),
|
||||
("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")),
|
||||
("wan", _wan_map_fn),
|
||||
("wan-i2v", _wan_i2v_map_fn),
|
||||
("flux", _create_default_map_fn("FluxModularPipeline")),
|
||||
("flux-kontext", _create_default_map_fn("FluxKontextModularPipeline")),
|
||||
("flux2", _create_default_map_fn("Flux2ModularPipeline")),
|
||||
("flux2-klein", _flux2_klein_map_fn),
|
||||
("qwenimage", _create_default_map_fn("QwenImageModularPipeline")),
|
||||
("qwenimage-edit", _create_default_map_fn("QwenImageEditModularPipeline")),
|
||||
("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")),
|
||||
("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")),
|
||||
("z-image", _create_default_map_fn("ZImageModularPipeline")),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -366,7 +408,8 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub.
|
||||
"""
|
||||
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
|
||||
map_fn = MODULAR_PIPELINE_MAPPING.get(self.model_name, _create_default_map_fn("ModularPipeline"))
|
||||
pipeline_class_name = map_fn()
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
pipeline_class = getattr(diffusers_module, pipeline_class_name)
|
||||
|
||||
@@ -1545,7 +1588,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
if modular_config_dict is not None:
|
||||
blocks_class_name = modular_config_dict.get("_blocks_class_name")
|
||||
else:
|
||||
blocks_class_name = self.get_default_blocks_name(config_dict)
|
||||
blocks_class_name = self.default_blocks_name
|
||||
if blocks_class_name is not None:
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
@@ -1617,9 +1660,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
params[input_param.name] = input_param.default
|
||||
return params
|
||||
|
||||
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
return self.default_blocks_name
|
||||
|
||||
@classmethod
|
||||
def _load_pipeline_config(
|
||||
cls,
|
||||
@@ -1715,7 +1755,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
logger.debug(" try to determine the modular pipeline class from model_index.json")
|
||||
standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
|
||||
model_name = _get_model(standard_pipeline_class.__name__)
|
||||
pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
|
||||
map_fn = MODULAR_PIPELINE_MAPPING.get(model_name, _create_default_map_fn("ModularPipeline"))
|
||||
pipeline_class_name = map_fn(config_dict)
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
pipeline_class = getattr(diffusers_module, pipeline_class_name)
|
||||
else:
|
||||
|
||||
@@ -21,16 +21,16 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["decoders"] = ["WanImageVaeDecoderStep"]
|
||||
_import_structure["encoders"] = ["WanTextEncoderStep"]
|
||||
_import_structure["modular_blocks"] = [
|
||||
"ALL_BLOCKS",
|
||||
"Wan22AutoBlocks",
|
||||
"WanAutoBlocks",
|
||||
"WanAutoImageEncoderStep",
|
||||
"WanAutoVaeImageEncoderStep",
|
||||
_import_structure["modular_blocks_wan"] = ["WanBlocks"]
|
||||
_import_structure["modular_blocks_wan22"] = ["Wan22Blocks"]
|
||||
_import_structure["modular_blocks_wan22_i2v"] = ["Wan22Image2VideoBlocks"]
|
||||
_import_structure["modular_blocks_wan_i2v"] = ["WanImage2VideoAutoBlocks"]
|
||||
_import_structure["modular_pipeline"] = [
|
||||
"Wan22Image2VideoModularPipeline",
|
||||
"Wan22ModularPipeline",
|
||||
"WanImage2VideoModularPipeline",
|
||||
"WanModularPipeline",
|
||||
]
|
||||
_import_structure["modular_pipeline"] = ["WanModularPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -39,16 +39,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .decoders import WanImageVaeDecoderStep
|
||||
from .encoders import WanTextEncoderStep
|
||||
from .modular_blocks import (
|
||||
ALL_BLOCKS,
|
||||
Wan22AutoBlocks,
|
||||
WanAutoBlocks,
|
||||
WanAutoImageEncoderStep,
|
||||
WanAutoVaeImageEncoderStep,
|
||||
from .modular_blocks_wan import WanBlocks
|
||||
from .modular_blocks_wan22 import Wan22Blocks
|
||||
from .modular_blocks_wan22_i2v import Wan22Image2VideoBlocks
|
||||
from .modular_blocks_wan_i2v import WanImage2VideoAutoBlocks
|
||||
from .modular_pipeline import (
|
||||
Wan22Image2VideoModularPipeline,
|
||||
Wan22ModularPipeline,
|
||||
WanImage2VideoModularPipeline,
|
||||
WanModularPipeline,
|
||||
)
|
||||
from .modular_pipeline import WanModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -280,7 +280,7 @@ class WanAdditionalInputsStep(ModularPipelineBlocks):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["first_frame_latents"],
|
||||
image_latent_inputs: List[str] = ["image_condition_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
|
||||
@@ -294,20 +294,16 @@ class WanAdditionalInputsStep(ModularPipelineBlocks):
|
||||
Args:
|
||||
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
|
||||
In additional to adjust batch size of these inputs, they will be used to determine height/width. Can be
|
||||
a single string or list of strings. Defaults to ["first_frame_latents"].
|
||||
a single string or list of strings. Defaults to ["image_condition_latents"].
|
||||
additional_batch_inputs (List[str], optional):
|
||||
Names of additional conditional input tensors to expand batch size. These tensors will only have their
|
||||
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
|
||||
Defaults to [].
|
||||
|
||||
Examples:
|
||||
# Configure to process first_frame_latents (default behavior) WanAdditionalInputsStep()
|
||||
|
||||
# Configure to process multiple image latent inputs
|
||||
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents", "last_frame_latents"])
|
||||
|
||||
# Configure to process image latents and additional batch inputs WanAdditionalInputsStep(
|
||||
image_latent_inputs=["first_frame_latents"], additional_batch_inputs=["image_embeds"]
|
||||
# Configure to process image_condition_latents (default behavior) WanAdditionalInputsStep() # Configure to
|
||||
process image latents and additional batch inputs WanAdditionalInputsStep(
|
||||
image_latent_inputs=["image_condition_latents"], additional_batch_inputs=["image_embeds"]
|
||||
)
|
||||
"""
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
@@ -557,81 +553,3 @@ class WanPrepareLatentsStep(ModularPipelineBlocks):
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that prepares the masked first frame latents and add it to the latent condition"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_frames", type_hint=int),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
|
||||
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
first_frame_mask = torch.repeat_interleave(
|
||||
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
|
||||
)
|
||||
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||
mask_lat_size = mask_lat_size.view(
|
||||
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
|
||||
)
|
||||
mask_lat_size = mask_lat_size.transpose(1, 2)
|
||||
mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
|
||||
block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that prepares the masked latents with first and last frames and add it to the latent condition"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_frames", type_hint=int),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0
|
||||
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
first_frame_mask = torch.repeat_interleave(
|
||||
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
|
||||
)
|
||||
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||
mask_lat_size = mask_lat_size.view(
|
||||
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
|
||||
)
|
||||
mask_lat_size = mask_lat_size.transpose(1, 2)
|
||||
mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
|
||||
block_state.first_last_frame_latents = torch.concat(
|
||||
[mask_lat_size, block_state.first_last_frame_latents], dim=1
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
@@ -29,7 +29,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class WanImageVaeDecoderStep(ModularPipelineBlocks):
|
||||
class WanVaeDecoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
|
||||
@@ -89,52 +89,10 @@ class WanImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
InputParam(
|
||||
"first_frame_latents",
|
||||
"image_condition_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The first frame latents to use for the denoising process. Can be generated in prepare_first_frame_latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"dtype",
|
||||
required=True,
|
||||
type_hint=torch.dtype,
|
||||
description="The dtype of the model inputs. Can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
block_state.latent_model_input = torch.cat([block_state.latents, block_state.first_frame_latents], dim=1).to(
|
||||
block_state.dtype
|
||||
)
|
||||
return components, block_state
|
||||
|
||||
|
||||
class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"step within the denoising loop that prepares the latent input for the denoiser. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `WanDenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||
),
|
||||
InputParam(
|
||||
"first_last_frame_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The first and last frame latents to use for the denoising process. Can be generated in prepare_first_last_frame_latents step.",
|
||||
description="The image condition latents to use for the denoising process. Can be generated in prepare_first_frame_latents/prepare_first_last_frame_latents step.",
|
||||
),
|
||||
InputParam(
|
||||
"dtype",
|
||||
@@ -147,7 +105,7 @@ class WanFLF2VLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
block_state.latent_model_input = torch.cat(
|
||||
[block_state.latents, block_state.first_last_frame_latents], dim=1
|
||||
[block_state.latents, block_state.image_condition_latents], dim=1
|
||||
).to(block_state.dtype)
|
||||
return components, block_state
|
||||
|
||||
@@ -584,29 +542,3 @@ class Wan22Image2VideoDenoiseStep(WanDenoiseLoopWrapper):
|
||||
" - `WanLoopAfterDenoiser`\n"
|
||||
"This block supports image-to-video tasks for Wan2.2."
|
||||
)
|
||||
|
||||
|
||||
class WanFLF2VDenoiseStep(WanDenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
WanFLF2VLoopBeforeDenoiser,
|
||||
WanLoopDenoiser(
|
||||
guider_input_fields={
|
||||
"encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
|
||||
"encoder_hidden_states_image": "image_embeds",
|
||||
}
|
||||
),
|
||||
WanLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||
" - `WanFLF2VLoopBeforeDenoiser`\n"
|
||||
" - `WanLoopDenoiser`\n"
|
||||
" - `WanLoopAfterDenoiser`\n"
|
||||
"This block supports FLF2V tasks for wan2.1."
|
||||
)
|
||||
|
||||
@@ -468,7 +468,7 @@ class WanFirstLastFrameImageEncoderStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class WanVaeImageEncoderStep(ModularPipelineBlocks):
|
||||
class WanVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -493,7 +493,7 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks):
|
||||
InputParam("resized_image", type_hint=PIL.Image.Image, required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("num_frames"),
|
||||
InputParam("num_frames", type_hint=int, default=81),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@@ -564,7 +564,51 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
|
||||
class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that prepares the masked first frame latents and add it to the latent condition"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_frames", required=True),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("image_condition_latents", type_hint=Optional[torch.Tensor]),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
|
||||
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
first_frame_mask = torch.repeat_interleave(
|
||||
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
|
||||
)
|
||||
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||
mask_lat_size = mask_lat_size.view(
|
||||
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
|
||||
)
|
||||
mask_lat_size = mask_lat_size.transpose(1, 2)
|
||||
mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
|
||||
block_state.image_condition_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanFirstLastFrameVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
@@ -590,7 +634,7 @@ class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
|
||||
InputParam("resized_last_image", type_hint=PIL.Image.Image, required=True),
|
||||
InputParam("height"),
|
||||
InputParam("width"),
|
||||
InputParam("num_frames"),
|
||||
InputParam("num_frames", type_hint=int, default=81),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@@ -667,3 +711,49 @@ class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "wan"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "step that prepares the masked latents with first and last frames and add it to the latent condition"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_frames", type_hint=int, required=True),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("image_condition_latents", type_hint=Optional[torch.Tensor]),
|
||||
]
|
||||
|
||||
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape
|
||||
|
||||
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
|
||||
mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0
|
||||
|
||||
first_frame_mask = mask_lat_size[:, :, 0:1]
|
||||
first_frame_mask = torch.repeat_interleave(
|
||||
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
|
||||
)
|
||||
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
||||
mask_lat_size = mask_lat_size.view(
|
||||
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
|
||||
)
|
||||
mask_lat_size = mask_lat_size.transpose(1, 2)
|
||||
mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
|
||||
block_state.image_condition_latents = torch.concat(
|
||||
[mask_lat_size, block_state.first_last_frame_latents], dim=1
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
@@ -1,474 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
WanAdditionalInputsStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
WanPrepareFirstLastFrameLatentsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanSetTimestepsStep,
|
||||
WanTextInputStep,
|
||||
)
|
||||
from .decoders import WanImageVaeDecoderStep
|
||||
from .denoise import (
|
||||
Wan22DenoiseStep,
|
||||
Wan22Image2VideoDenoiseStep,
|
||||
WanDenoiseStep,
|
||||
WanFLF2VDenoiseStep,
|
||||
WanImage2VideoDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
WanFirstLastFrameImageEncoderStep,
|
||||
WanFirstLastFrameVaeImageEncoderStep,
|
||||
WanImageCropResizeStep,
|
||||
WanImageEncoderStep,
|
||||
WanImageResizeStep,
|
||||
WanTextEncoderStep,
|
||||
WanVaeImageEncoderStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# wan2.1
|
||||
# wan2.1: text2vid
|
||||
class WanCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# wan2.1: image2video
|
||||
## image encoder
|
||||
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanImageEncoderStep]
|
||||
block_names = ["image_resize", "image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"
|
||||
|
||||
|
||||
## vae encoder
|
||||
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
|
||||
block_names = ["image_resize", "vae_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
|
||||
|
||||
|
||||
## denoise
|
||||
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
WanImage2VideoDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"prepare_first_frame_latents",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
|
||||
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# wan2.1: FLF2v
|
||||
|
||||
|
||||
## image encoder
|
||||
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
|
||||
block_names = ["image_resize", "last_image_resize", "image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings"
|
||||
|
||||
|
||||
## vae encoder
|
||||
class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
|
||||
block_names = ["image_resize", "last_image_resize", "vae_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions"
|
||||
|
||||
|
||||
## denoise
|
||||
class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"]),
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanPrepareFirstLastFrameLatentsStep,
|
||||
WanFLF2VDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"prepare_first_last_frame_latents",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanPrepareFirstLastFrameLatentsStep` is used to prepare the latent conditions\n"
|
||||
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# wan2.1: auto blocks
|
||||
## image encoder
|
||||
class WanAutoImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
|
||||
block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Image Encoder step that encode the image to generate the image embeddings"
|
||||
+ "This is an auto pipeline block that works for image2video tasks."
|
||||
+ " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided."
|
||||
+ " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided."
|
||||
+ " - if `last_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## vae encoder
|
||||
class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
|
||||
block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae Image Encoder step that encode the image to generate the image latents"
|
||||
+ "This is an auto pipeline block that works for image2video tasks."
|
||||
+ " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided."
|
||||
+ " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided."
|
||||
+ " - if `last_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
## denoise
|
||||
class WanAutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
WanFLF2VCoreDenoiseStep,
|
||||
WanImage2VideoCoreDenoiseStep,
|
||||
WanCoreDenoiseStep,
|
||||
]
|
||||
block_names = ["flf2v", "image2video", "text2video"]
|
||||
block_trigger_inputs = ["first_last_frame_latents", "first_frame_latents", None]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2video and image2video tasks."
|
||||
" - `WanCoreDenoiseStep` (text2video) for text2vid tasks."
|
||||
" - `WanCoreImage2VideoCoreDenoiseStep` (image2video) for image2video tasks."
|
||||
+ " - if `first_frame_latents` is provided, `WanCoreImage2VideoDenoiseStep` will be used.\n"
|
||||
+ " - if `first_frame_latents` is not provided, `WanCoreDenoiseStep` will be used.\n"
|
||||
)
|
||||
|
||||
|
||||
# auto pipeline blocks
|
||||
class WanAutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
WanAutoImageEncoderStep,
|
||||
WanAutoVaeImageEncoderStep,
|
||||
WanAutoDenoiseStep,
|
||||
WanImageVaeDecoderStep,
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"image_encoder",
|
||||
"vae_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-video using Wan.\n"
|
||||
+ "- for text-to-video generation, all you need to provide is `prompt`"
|
||||
)
|
||||
|
||||
|
||||
# wan22
|
||||
# wan2.2: text2vid
|
||||
|
||||
|
||||
## denoise
|
||||
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
Wan22DenoiseStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
|
||||
)
|
||||
|
||||
|
||||
# wan2.2: image2video
|
||||
## denoise
|
||||
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
Wan22Image2VideoDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"prepare_first_frame_latents",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
|
||||
+ " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
|
||||
)
|
||||
|
||||
|
||||
class Wan22AutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
Wan22Image2VideoCoreDenoiseStep,
|
||||
Wan22CoreDenoiseStep,
|
||||
]
|
||||
block_names = ["image2video", "text2video"]
|
||||
block_trigger_inputs = ["first_frame_latents", None]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2video and image2video tasks."
|
||||
" - `Wan22Image2VideoCoreDenoiseStep` (image2video) for image2video tasks."
|
||||
" - `Wan22CoreDenoiseStep` (text2video) for text2vid tasks."
|
||||
+ " - if `first_frame_latents` is provided, `Wan22Image2VideoCoreDenoiseStep` will be used.\n"
|
||||
+ " - if `first_frame_latents` is not provided, `Wan22CoreDenoiseStep` will be used.\n"
|
||||
)
|
||||
|
||||
|
||||
class Wan22AutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
WanAutoVaeImageEncoderStep,
|
||||
Wan22AutoDenoiseStep,
|
||||
WanImageVaeDecoderStep,
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"vae_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-video using Wan2.2.\n"
|
||||
+ "- for text-to-video generation, all you need to provide is `prompt`"
|
||||
)
|
||||
|
||||
|
||||
# presets for wan2.1 and wan2.2
|
||||
# YiYi Notes: should we move these to doc?
|
||||
# wan2.1
|
||||
TEXT2VIDEO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("denoise", WanDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE2VIDEO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("image_encoder", WanImage2VideoImageEncoderStep),
|
||||
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep),
|
||||
("denoise", WanImage2VideoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
FLF2V_BLOCKS = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("last_image_resize", WanImageCropResizeStep),
|
||||
("image_encoder", WanFLF2VImageEncoderStep),
|
||||
("vae_encoder", WanFLF2VVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("prepare_first_last_frame_latents", WanPrepareFirstLastFrameLatentsStep),
|
||||
("denoise", WanFLF2VDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("image_encoder", WanAutoImageEncoderStep),
|
||||
("vae_encoder", WanAutoVaeImageEncoderStep),
|
||||
("denoise", WanAutoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
# wan2.2 presets
|
||||
|
||||
TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("denoise", Wan22DenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
("denoise", Wan22DenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("vae_encoder", WanAutoVaeImageEncoderStep),
|
||||
("denoise", Wan22AutoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
# presets all blocks (wan and wan22)
|
||||
|
||||
|
||||
ALL_BLOCKS = {
|
||||
"wan2.1": {
|
||||
"text2video": TEXT2VIDEO_BLOCKS,
|
||||
"image2video": IMAGE2VIDEO_BLOCKS,
|
||||
"flf2v": FLF2V_BLOCKS,
|
||||
"auto": AUTO_BLOCKS,
|
||||
},
|
||||
"wan2.2": {
|
||||
"text2video": TEXT2VIDEO_BLOCKS_WAN22,
|
||||
"image2video": IMAGE2VIDEO_BLOCKS_WAN22,
|
||||
"auto": AUTO_BLOCKS_WAN22,
|
||||
},
|
||||
}
|
||||
83
src/diffusers/modular_pipelines/wan/modular_blocks_wan.py
Normal file
83
src/diffusers/modular_pipelines/wan/modular_blocks_wan.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import SequentialPipelineBlocks
|
||||
from .before_denoise import (
|
||||
WanPrepareLatentsStep,
|
||||
WanSetTimestepsStep,
|
||||
WanTextInputStep,
|
||||
)
|
||||
from .decoders import WanVaeDecoderStep
|
||||
from .denoise import (
|
||||
WanDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
WanTextEncoderStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. DENOISE
|
||||
# ====================
|
||||
|
||||
|
||||
# inputs(text) -> set_timesteps -> prepare_latents -> denoise
|
||||
class WanCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. BLOCKS (Wan2.1 text2video)
|
||||
# ====================
|
||||
|
||||
|
||||
class WanBlocks(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
WanCoreDenoiseStep,
|
||||
WanVaeDecoderStep,
|
||||
]
|
||||
block_names = ["text_encoder", "denoise", "decode"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Modular pipeline blocks for Wan2.1.\n"
|
||||
+ "- `WanTextEncoderStep` is used to encode the text\n"
|
||||
+ "- `WanCoreDenoiseStep` is used to denoise the latents\n"
|
||||
+ "- `WanVaeDecoderStep` is used to decode the latents to images"
|
||||
)
|
||||
88
src/diffusers/modular_pipelines/wan/modular_blocks_wan22.py
Normal file
88
src/diffusers/modular_pipelines/wan/modular_blocks_wan22.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import SequentialPipelineBlocks
|
||||
from .before_denoise import (
|
||||
WanPrepareLatentsStep,
|
||||
WanSetTimestepsStep,
|
||||
WanTextInputStep,
|
||||
)
|
||||
from .decoders import WanVaeDecoderStep
|
||||
from .denoise import (
|
||||
Wan22DenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
WanTextEncoderStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. DENOISE
|
||||
# ====================
|
||||
|
||||
# inputs(text) -> set_timesteps -> prepare_latents -> denoise
|
||||
|
||||
|
||||
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
Wan22DenoiseStep,
|
||||
]
|
||||
block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. BLOCKS (Wan2.2 text2video)
|
||||
# ====================
|
||||
|
||||
|
||||
class Wan22Blocks(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
Wan22CoreDenoiseStep,
|
||||
WanVaeDecoderStep,
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Modular pipeline for text-to-video using Wan2.2.\n"
|
||||
+ " - `WanTextEncoderStep` encodes the text\n"
|
||||
+ " - `Wan22CoreDenoiseStep` denoes the latents\n"
|
||||
+ " - `WanVaeDecoderStep` decodes the latents to video frames\n"
|
||||
)
|
||||
117
src/diffusers/modular_pipelines/wan/modular_blocks_wan22_i2v.py
Normal file
117
src/diffusers/modular_pipelines/wan/modular_blocks_wan22_i2v.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import SequentialPipelineBlocks
|
||||
from .before_denoise import (
|
||||
WanAdditionalInputsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanSetTimestepsStep,
|
||||
WanTextInputStep,
|
||||
)
|
||||
from .decoders import WanVaeDecoderStep
|
||||
from .denoise import (
|
||||
Wan22Image2VideoDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
WanImageResizeStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
WanTextEncoderStep,
|
||||
WanVaeEncoderStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep]
|
||||
block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. DENOISE
|
||||
# ====================
|
||||
|
||||
|
||||
# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents)
|
||||
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]),
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
Wan22Image2VideoDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. BLOCKS (Wan2.2 Image2Video)
|
||||
# ====================
|
||||
|
||||
|
||||
class Wan22Image2VideoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
WanImage2VideoVaeEncoderStep,
|
||||
Wan22Image2VideoCoreDenoiseStep,
|
||||
WanVaeDecoderStep,
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"vae_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Modular pipeline for image-to-video using Wan2.2.\n"
|
||||
+ " - `WanTextEncoderStep` encodes the text\n"
|
||||
+ " - `WanImage2VideoVaeEncoderStep` encodes the image\n"
|
||||
+ " - `Wan22Image2VideoCoreDenoiseStep` denoes the latents\n"
|
||||
+ " - `WanVaeDecoderStep` decodes the latents to video frames\n"
|
||||
)
|
||||
203
src/diffusers/modular_pipelines/wan/modular_blocks_wan_i2v.py
Normal file
203
src/diffusers/modular_pipelines/wan/modular_blocks_wan_i2v.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from .before_denoise import (
|
||||
WanAdditionalInputsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanSetTimestepsStep,
|
||||
WanTextInputStep,
|
||||
)
|
||||
from .decoders import WanVaeDecoderStep
|
||||
from .denoise import (
|
||||
WanImage2VideoDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
WanFirstLastFrameImageEncoderStep,
|
||||
WanFirstLastFrameVaeEncoderStep,
|
||||
WanImageCropResizeStep,
|
||||
WanImageEncoderStep,
|
||||
WanImageResizeStep,
|
||||
WanPrepareFirstFrameLatentsStep,
|
||||
WanPrepareFirstLastFrameLatentsStep,
|
||||
WanTextEncoderStep,
|
||||
WanVaeEncoderStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# ====================
|
||||
# 1. IMAGE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
# wan2.1 I2V (first frame only)
|
||||
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanImageResizeStep, WanImageEncoderStep]
|
||||
block_names = ["image_resize", "image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"
|
||||
|
||||
|
||||
# wan2.1 FLF2V (first and last frame)
|
||||
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
|
||||
block_names = ["image_resize", "last_image_resize", "image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image embeddings"
|
||||
|
||||
|
||||
# wan2.1 Auto Image Encoder
|
||||
class WanAutoImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
|
||||
block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
model_name = "wan-i2v"
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Image Encoder step that encode the image to generate the image embeddings"
|
||||
+ "This is an auto pipeline block that works for image2video tasks."
|
||||
+ " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided."
|
||||
+ " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided."
|
||||
+ " - if `last_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
# wan2.1 I2V (first frame only)
|
||||
class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep]
|
||||
block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
|
||||
|
||||
|
||||
# wan2.1 FLF2V (first and last frame)
|
||||
class WanFLF2VVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [
|
||||
WanImageResizeStep,
|
||||
WanImageCropResizeStep,
|
||||
WanFirstLastFrameVaeEncoderStep,
|
||||
WanPrepareFirstLastFrameLatentsStep,
|
||||
]
|
||||
block_names = ["image_resize", "last_image_resize", "vae_encoder", "prepare_first_last_frame_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the latent conditions"
|
||||
|
||||
|
||||
# wan2.1 Auto Vae Encoder
|
||||
class WanAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanFLF2VVaeEncoderStep, WanImage2VideoVaeEncoderStep]
|
||||
block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae Image Encoder step that encode the image to generate the image latents"
|
||||
+ "This is an auto pipeline block that works for image2video tasks."
|
||||
+ " - `WanFLF2VVaeEncoderStep` (flf2v) is used when `last_image` is provided."
|
||||
+ " - `WanImage2VideoVaeEncoderStep` (image2video) is used when `image` is provided."
|
||||
+ " - if `last_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. DENOISE (inputs -> set_timesteps -> prepare_latents -> denoise)
|
||||
# ====================
|
||||
|
||||
|
||||
# wan2.1 I2V core denoise (support both I2V and FLF2V)
|
||||
# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents)
|
||||
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
WanAdditionalInputsStep(image_latent_inputs=["image_condition_latents"]),
|
||||
WanSetTimestepsStep,
|
||||
WanPrepareLatentsStep,
|
||||
WanImage2VideoDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 4. BLOCKS (Wan2.1 Image2Video)
|
||||
# ====================
|
||||
|
||||
|
||||
# wan2.1 Image2Video Auto Blocks
|
||||
class WanImage2VideoAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
WanAutoImageEncoderStep,
|
||||
WanAutoVaeEncoderStep,
|
||||
WanImage2VideoCoreDenoiseStep,
|
||||
WanVaeDecoderStep,
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"image_encoder",
|
||||
"vae_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for image-to-video using Wan.\n"
|
||||
+ "- for I2V workflow, all you need to provide is `image`"
|
||||
+ "- for FLF2V workflow, all you need to provide is `last_image` and `image`"
|
||||
)
|
||||
@@ -13,8 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ...loaders import WanLoraLoaderMixin
|
||||
from ...pipelines.pipeline_utils import StableDiffusionMixin
|
||||
from ...utils import logging
|
||||
@@ -30,19 +28,12 @@ class WanModularPipeline(
|
||||
WanLoraLoaderMixin,
|
||||
):
|
||||
"""
|
||||
A ModularPipeline for Wan.
|
||||
A ModularPipeline for Wan2.1 text2video.
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "WanAutoBlocks"
|
||||
|
||||
# override the default_blocks_name in base class, which is just return self.default_blocks_name
|
||||
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
|
||||
if config_dict is not None and "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
|
||||
return "Wan22AutoBlocks"
|
||||
else:
|
||||
return "WanAutoBlocks"
|
||||
default_blocks_name = "WanBlocks"
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
@@ -118,3 +109,33 @@ class WanModularPipeline(
|
||||
if hasattr(self, "scheduler") and self.scheduler is not None:
|
||||
num_train_timesteps = self.scheduler.config.num_train_timesteps
|
||||
return num_train_timesteps
|
||||
|
||||
|
||||
class WanImage2VideoModularPipeline(WanModularPipeline):
|
||||
"""
|
||||
A ModularPipeline for Wan2.1 image2video (both I2V and FLF2V).
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "WanImage2VideoAutoBlocks"
|
||||
|
||||
|
||||
class Wan22ModularPipeline(WanModularPipeline):
|
||||
"""
|
||||
A ModularPipeline for Wan2.2 text2video.
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "Wan22Blocks"
|
||||
|
||||
|
||||
class Wan22Image2VideoModularPipeline(Wan22ModularPipeline):
|
||||
"""
|
||||
A ModularPipeline for Wan2.2 image2video.
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "Wan22Image2VideoBlocks"
|
||||
|
||||
@@ -246,7 +246,7 @@ AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
|
||||
|
||||
AUTO_IMAGE2VIDEO_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("wan", WanImageToVideoPipeline),
|
||||
("wan-i2v", WanImageToVideoPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -47,6 +47,21 @@ class Flux2KleinBaseAutoBlocks(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinBaseModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -287,7 +302,7 @@ class StableDiffusionXLModularPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Wan22AutoBlocks(metaclass=DummyObject):
|
||||
class Wan22Blocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -302,7 +317,82 @@ class Wan22AutoBlocks(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanAutoBlocks(metaclass=DummyObject):
|
||||
class Wan22Image2VideoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Wan22Image2VideoModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Wan22ModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanImage2VideoAutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class WanImage2VideoModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
@@ -32,22 +32,6 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
|
||||
|
||||
def pytest_configure(config):
|
||||
config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
|
||||
config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
|
||||
config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
|
||||
config.addinivalue_line("markers", "training: marks tests for training functionality")
|
||||
config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
|
||||
config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
|
||||
config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
|
||||
config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
|
||||
config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
|
||||
config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
|
||||
config.addinivalue_line("markers", "quantization: marks tests for quantization functionality")
|
||||
config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
|
||||
config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
|
||||
config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
|
||||
config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
|
||||
config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
|
||||
config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality")
|
||||
config.addinivalue_line("markers", "slow: mark test as slow")
|
||||
config.addinivalue_line("markers", "nightly: mark test as nightly")
|
||||
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
from .attention import AttentionTesterMixin
|
||||
from .cache import (
|
||||
CacheTesterMixin,
|
||||
FasterCacheConfigMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheConfigMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
PyramidAttentionBroadcastConfigMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
)
|
||||
from .common import BaseModelTesterConfig, ModelTesterMixin
|
||||
from .compile import TorchCompileTesterMixin
|
||||
from .ip_adapter import IPAdapterTesterMixin
|
||||
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
|
||||
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
|
||||
from .parallelism import ContextParallelTesterMixin
|
||||
from .quantization import (
|
||||
BitsAndBytesCompileTesterMixin,
|
||||
BitsAndBytesConfigMixin,
|
||||
BitsAndBytesTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFConfigMixin,
|
||||
GGUFTesterMixin,
|
||||
ModelOptCompileTesterMixin,
|
||||
ModelOptConfigMixin,
|
||||
ModelOptTesterMixin,
|
||||
QuantizationCompileTesterMixin,
|
||||
QuantizationTesterMixin,
|
||||
QuantoCompileTesterMixin,
|
||||
QuantoConfigMixin,
|
||||
QuantoTesterMixin,
|
||||
TorchAoCompileTesterMixin,
|
||||
TorchAoConfigMixin,
|
||||
TorchAoTesterMixin,
|
||||
)
|
||||
from .single_file import SingleFileTesterMixin
|
||||
from .training import TrainingTesterMixin
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AttentionTesterMixin",
|
||||
"BaseModelTesterConfig",
|
||||
"BitsAndBytesCompileTesterMixin",
|
||||
"BitsAndBytesConfigMixin",
|
||||
"BitsAndBytesTesterMixin",
|
||||
"CacheTesterMixin",
|
||||
"ContextParallelTesterMixin",
|
||||
"CPUOffloadTesterMixin",
|
||||
"FasterCacheConfigMixin",
|
||||
"FasterCacheTesterMixin",
|
||||
"FirstBlockCacheConfigMixin",
|
||||
"FirstBlockCacheTesterMixin",
|
||||
"GGUFCompileTesterMixin",
|
||||
"GGUFConfigMixin",
|
||||
"GGUFTesterMixin",
|
||||
"GroupOffloadTesterMixin",
|
||||
"IPAdapterTesterMixin",
|
||||
"LayerwiseCastingTesterMixin",
|
||||
"LoraHotSwappingForModelTesterMixin",
|
||||
"LoraTesterMixin",
|
||||
"MemoryTesterMixin",
|
||||
"ModelOptCompileTesterMixin",
|
||||
"ModelOptConfigMixin",
|
||||
"ModelOptTesterMixin",
|
||||
"ModelTesterMixin",
|
||||
"PyramidAttentionBroadcastConfigMixin",
|
||||
"PyramidAttentionBroadcastTesterMixin",
|
||||
"QuantizationCompileTesterMixin",
|
||||
"QuantizationTesterMixin",
|
||||
"QuantoCompileTesterMixin",
|
||||
"QuantoConfigMixin",
|
||||
"QuantoTesterMixin",
|
||||
"SingleFileTesterMixin",
|
||||
"TorchAoCompileTesterMixin",
|
||||
"TorchAoConfigMixin",
|
||||
"TorchAoTesterMixin",
|
||||
"TorchCompileTesterMixin",
|
||||
"TrainingTesterMixin",
|
||||
]
|
||||
@@ -1,181 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.models.attention import AttentionModuleMixin
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor,
|
||||
)
|
||||
|
||||
from ...testing_utils import (
|
||||
assert_tensors_close,
|
||||
backend_empty_cache,
|
||||
is_attention,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@is_attention
|
||||
class AttentionTesterMixin:
|
||||
"""
|
||||
Mixin class for testing attention processor and module functionality on models.
|
||||
|
||||
Tests functionality from AttentionModuleMixin including:
|
||||
- Attention processor management (set/get)
|
||||
- QKV projection fusion/unfusion
|
||||
- Attention backends (XFormers, NPU, etc.)
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: attention
|
||||
Use `pytest -m "not attention"` to skip these tests
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_fuse_unfuse_qkv_projections(self, atol=1e-3, rtol=0):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
if not hasattr(model, "fuse_qkv_projections"):
|
||||
pytest.skip("Model does not support QKV projection fusion.")
|
||||
|
||||
output_before_fusion = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model.fuse_qkv_projections()
|
||||
|
||||
has_fused_projections = False
|
||||
for module in model.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
|
||||
has_fused_projections = True
|
||||
assert module.fused_projections, "fused_projections flag should be True"
|
||||
break
|
||||
|
||||
if has_fused_projections:
|
||||
output_after_fusion = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
output_before_fusion,
|
||||
output_after_fusion,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output should not change after fusing projections",
|
||||
)
|
||||
|
||||
model.unfuse_qkv_projections()
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
|
||||
assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
|
||||
assert not module.fused_projections, "fused_projections flag should be False"
|
||||
|
||||
output_after_unfusion = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
output_before_fusion,
|
||||
output_after_unfusion,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output should match original after unfusing projections",
|
||||
)
|
||||
|
||||
def test_get_set_processor(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
# Check if model has attention processors
|
||||
if not hasattr(model, "attn_processors"):
|
||||
pytest.skip("Model does not have attention processors.")
|
||||
|
||||
# Test getting processors
|
||||
processors = model.attn_processors
|
||||
assert isinstance(processors, dict), "attn_processors should return a dict"
|
||||
assert len(processors) > 0, "Model should have at least one attention processor"
|
||||
|
||||
# Test that all processors can be retrieved via get_processor
|
||||
for module in model.modules():
|
||||
if isinstance(module, AttentionModuleMixin):
|
||||
processor = module.get_processor()
|
||||
assert processor is not None, "get_processor should return a processor"
|
||||
|
||||
# Test setting a new processor
|
||||
new_processor = AttnProcessor()
|
||||
module.set_processor(new_processor)
|
||||
retrieved_processor = module.get_processor()
|
||||
assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"
|
||||
|
||||
def test_attention_processor_dict(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
if not hasattr(model, "set_attn_processor"):
|
||||
pytest.skip("Model does not support setting attention processors.")
|
||||
|
||||
# Get current processors
|
||||
current_processors = model.attn_processors
|
||||
|
||||
# Create a dict of new processors
|
||||
new_processors = {key: AttnProcessor() for key in current_processors.keys()}
|
||||
|
||||
# Set processors using dict
|
||||
model.set_attn_processor(new_processors)
|
||||
|
||||
# Verify all processors were set
|
||||
updated_processors = model.attn_processors
|
||||
for key in current_processors.keys():
|
||||
assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"
|
||||
|
||||
def test_attention_processor_count_mismatch_raises_error(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
if not hasattr(model, "set_attn_processor"):
|
||||
pytest.skip("Model does not support setting attention processors.")
|
||||
|
||||
# Get current processors
|
||||
current_processors = model.attn_processors
|
||||
|
||||
# Create a dict with wrong number of processors
|
||||
wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}
|
||||
|
||||
# Verify error is raised
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
model.set_attn_processor(wrong_processors)
|
||||
|
||||
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
|
||||
@@ -1,556 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.hooks import FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig
|
||||
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
||||
from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
|
||||
from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
|
||||
from diffusers.models.cache_utils import CacheMixin
|
||||
|
||||
from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device
|
||||
|
||||
|
||||
def require_cache_mixin(func):
|
||||
"""Decorator to skip tests if model doesn't use CacheMixin."""
|
||||
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not issubclass(self.model_class, CacheMixin):
|
||||
pytest.skip(f"{self.model_class.__name__} does not use CacheMixin.")
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class CacheTesterMixin:
|
||||
"""
|
||||
Base mixin class providing common test implementations for cache testing.
|
||||
|
||||
Cache-specific mixins should:
|
||||
1. Inherit from their respective config mixin (e.g., PyramidAttentionBroadcastConfigMixin)
|
||||
2. Inherit from this mixin
|
||||
3. Define the cache config to use for tests
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
|
||||
Expected methods in test classes:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Optional overrides:
|
||||
- cache_input_key: Property returning the input tensor key to vary between passes (default: "hidden_states")
|
||||
"""
|
||||
|
||||
@property
|
||||
def cache_input_key(self):
|
||||
return "hidden_states"
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def _get_cache_config(self):
|
||||
"""
|
||||
Get the cache config for testing.
|
||||
Should be implemented by subclasses.
|
||||
"""
|
||||
raise NotImplementedError("Subclass must implement _get_cache_config")
|
||||
|
||||
def _get_hook_names(self):
|
||||
"""
|
||||
Get the hook names to check for this cache type.
|
||||
Should be implemented by subclasses.
|
||||
Returns a list of hook name strings.
|
||||
"""
|
||||
raise NotImplementedError("Subclass must implement _get_hook_names")
|
||||
|
||||
def _test_cache_enable_disable_state(self):
|
||||
"""Test that cache enable/disable updates the is_cache_enabled state correctly."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Initially cache should not be enabled
|
||||
assert not model.is_cache_enabled, "Cache should not be enabled initially."
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
# Enable cache
|
||||
model.enable_cache(config)
|
||||
assert model.is_cache_enabled, "Cache should be enabled after enable_cache()."
|
||||
|
||||
# Disable cache
|
||||
model.disable_cache()
|
||||
assert not model.is_cache_enabled, "Cache should not be enabled after disable_cache()."
|
||||
|
||||
def _test_cache_double_enable_raises_error(self):
|
||||
"""Test that enabling cache twice raises an error."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# Trying to enable again should raise ValueError
|
||||
with pytest.raises(ValueError, match="Caching has already been enabled"):
|
||||
model.enable_cache(config)
|
||||
|
||||
# Cleanup
|
||||
model.disable_cache()
|
||||
|
||||
def _test_cache_hooks_registered(self):
|
||||
"""Test that cache hooks are properly registered and removed."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
config = self._get_cache_config()
|
||||
hook_names = self._get_hook_names()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# Check that at least one hook was registered
|
||||
hook_count = 0
|
||||
for module in model.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
for hook_name in hook_names:
|
||||
hook = module._diffusers_hook.get_hook(hook_name)
|
||||
if hook is not None:
|
||||
hook_count += 1
|
||||
|
||||
assert hook_count > 0, f"At least one cache hook should be registered. Hook names: {hook_names}"
|
||||
|
||||
# Disable and verify hooks are removed
|
||||
model.disable_cache()
|
||||
|
||||
hook_count_after = 0
|
||||
for module in model.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
for hook_name in hook_names:
|
||||
hook = module._diffusers_hook.get_hook(hook_name)
|
||||
if hook is not None:
|
||||
hook_count_after += 1
|
||||
|
||||
assert hook_count_after == 0, "Cache hooks should be removed after disable_cache()."
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_cache_inference(self):
|
||||
"""Test that model can run inference with cache enabled."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# First pass populates the cache
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Create modified inputs for second pass (vary input tensor to simulate denoising)
|
||||
inputs_dict_step2 = inputs_dict.copy()
|
||||
if self.cache_input_key in inputs_dict_step2:
|
||||
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
|
||||
inputs_dict_step2[self.cache_input_key]
|
||||
)
|
||||
|
||||
# Second pass uses cached attention with different inputs (produces approximated output)
|
||||
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
assert output_with_cache is not None, "Model output should not be None with cache enabled."
|
||||
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
|
||||
|
||||
# Run same inputs without cache to compare
|
||||
model.disable_cache()
|
||||
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
# Cached output should be different from non-cached output (due to approximation)
|
||||
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
|
||||
"Cached output should be different from non-cached output due to cache approximation."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_cache_context_manager(self, atol=1e-5, rtol=0):
|
||||
"""Test the cache_context context manager properly isolates cache state."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
model.enable_cache(config)
|
||||
|
||||
# Run inference in first context
|
||||
with model.cache_context("context_1"):
|
||||
output_ctx1 = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Run same inference in second context (cache should be reset)
|
||||
with model.cache_context("context_2"):
|
||||
output_ctx2 = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Both contexts should produce the same output (first pass in each)
|
||||
assert_tensors_close(
|
||||
output_ctx1,
|
||||
output_ctx2,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="First pass in different cache contexts should produce the same output.",
|
||||
)
|
||||
|
||||
model.disable_cache()
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_reset_stateful_cache(self):
|
||||
"""Test that _reset_stateful_cache resets the cache state."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model._reset_stateful_cache()
|
||||
|
||||
model.disable_cache()
|
||||
|
||||
|
||||
@is_cache
|
||||
class PyramidAttentionBroadcastConfigMixin:
|
||||
"""
|
||||
Base mixin providing PyramidAttentionBroadcast cache config.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
"""
|
||||
|
||||
# Default PAB config - can be overridden by subclasses
|
||||
PAB_CONFIG = {
|
||||
"spatial_attention_block_skip_range": 2,
|
||||
}
|
||||
|
||||
# Store timestep for callback (must be within default range (100, 800) for skipping to trigger)
|
||||
_current_timestep = 500
|
||||
|
||||
def _get_cache_config(self):
|
||||
config_kwargs = self.PAB_CONFIG.copy()
|
||||
config_kwargs["current_timestep_callback"] = lambda: self._current_timestep
|
||||
return PyramidAttentionBroadcastConfig(**config_kwargs)
|
||||
|
||||
def _get_hook_names(self):
|
||||
return [_PYRAMID_ATTENTION_BROADCAST_HOOK]
|
||||
|
||||
|
||||
@is_cache
|
||||
class PyramidAttentionBroadcastTesterMixin(PyramidAttentionBroadcastConfigMixin, CacheTesterMixin):
|
||||
"""
|
||||
Mixin class for testing PyramidAttentionBroadcast caching on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cache
|
||||
Use `pytest -m "not cache"` to skip these tests
|
||||
"""
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_enable_disable_state(self):
|
||||
self._test_cache_enable_disable_state()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_double_enable_raises_error(self):
|
||||
self._test_cache_double_enable_raises_error()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_hooks_registered(self):
|
||||
self._test_cache_hooks_registered()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_inference(self):
|
||||
self._test_cache_inference()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_cache_context_manager(self):
|
||||
self._test_cache_context_manager()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_pab_reset_stateful_cache(self):
|
||||
self._test_reset_stateful_cache()
|
||||
|
||||
|
||||
@is_cache
|
||||
class FirstBlockCacheConfigMixin:
|
||||
"""
|
||||
Base mixin providing FirstBlockCache config.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
"""
|
||||
|
||||
# Default FBC config - can be overridden by subclasses
|
||||
# Higher threshold makes FBC more aggressive about caching (skips more often)
|
||||
FBC_CONFIG = {
|
||||
"threshold": 1.0,
|
||||
}
|
||||
|
||||
def _get_cache_config(self):
|
||||
return FirstBlockCacheConfig(**self.FBC_CONFIG)
|
||||
|
||||
def _get_hook_names(self):
|
||||
return [_FBC_LEADER_BLOCK_HOOK, _FBC_BLOCK_HOOK]
|
||||
|
||||
|
||||
@is_cache
|
||||
class FirstBlockCacheTesterMixin(FirstBlockCacheConfigMixin, CacheTesterMixin):
|
||||
"""
|
||||
Mixin class for testing FirstBlockCache on models.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cache
|
||||
Use `pytest -m "not cache"` to skip these tests
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_cache_inference(self):
|
||||
"""Test that model can run inference with FBC cache enabled (requires cache_context)."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
model.enable_cache(config)
|
||||
|
||||
# FBC requires cache_context to be set for inference
|
||||
with model.cache_context("fbc_test"):
|
||||
# First pass populates the cache
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Create modified inputs for second pass
|
||||
inputs_dict_step2 = inputs_dict.copy()
|
||||
if self.cache_input_key in inputs_dict_step2:
|
||||
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
|
||||
inputs_dict_step2[self.cache_input_key]
|
||||
)
|
||||
|
||||
# Second pass - FBC should skip remaining blocks and use cached residuals
|
||||
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
assert output_with_cache is not None, "Model output should not be None with cache enabled."
|
||||
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
|
||||
|
||||
# Run same inputs without cache to compare
|
||||
model.disable_cache()
|
||||
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
# Cached output should be different from non-cached output (due to approximation)
|
||||
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
|
||||
"Cached output should be different from non-cached output due to cache approximation."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_reset_stateful_cache(self):
|
||||
"""Test that _reset_stateful_cache resets the FBC cache state (requires cache_context)."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
model.enable_cache(config)
|
||||
|
||||
with model.cache_context("fbc_test"):
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model._reset_stateful_cache()
|
||||
|
||||
model.disable_cache()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_enable_disable_state(self):
|
||||
self._test_cache_enable_disable_state()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_double_enable_raises_error(self):
|
||||
self._test_cache_double_enable_raises_error()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_hooks_registered(self):
|
||||
self._test_cache_hooks_registered()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_inference(self):
|
||||
self._test_cache_inference()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_cache_context_manager(self):
|
||||
self._test_cache_context_manager()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_fbc_reset_stateful_cache(self):
|
||||
self._test_reset_stateful_cache()
|
||||
|
||||
|
||||
@is_cache
|
||||
class FasterCacheConfigMixin:
|
||||
"""
|
||||
Base mixin providing FasterCache config.
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
"""
|
||||
|
||||
# Default FasterCache config - can be overridden by subclasses
|
||||
FASTER_CACHE_CONFIG = {
|
||||
"spatial_attention_block_skip_range": 2,
|
||||
"spatial_attention_timestep_skip_range": (-1, 901),
|
||||
"tensor_format": "BCHW",
|
||||
}
|
||||
|
||||
def _get_cache_config(self, current_timestep_callback=None):
|
||||
config_kwargs = self.FASTER_CACHE_CONFIG.copy()
|
||||
if current_timestep_callback is None:
|
||||
current_timestep_callback = lambda: 1000 # noqa: E731
|
||||
config_kwargs["current_timestep_callback"] = current_timestep_callback
|
||||
return FasterCacheConfig(**config_kwargs)
|
||||
|
||||
def _get_hook_names(self):
|
||||
return [_FASTER_CACHE_DENOISER_HOOK, _FASTER_CACHE_BLOCK_HOOK]
|
||||
|
||||
|
||||
@is_cache
|
||||
class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
|
||||
"""
|
||||
Mixin class for testing FasterCache on models.
|
||||
|
||||
Note: FasterCache is designed for pipeline-level inference with proper CFG batch handling
|
||||
and timestep management. Inference tests are skipped at model level - FasterCache should
|
||||
be tested via pipeline tests (e.g., FluxPipeline, HunyuanVideoPipeline).
|
||||
|
||||
Expected class attributes:
|
||||
- model_class: The model class to test (must use CacheMixin)
|
||||
|
||||
Expected methods to be implemented by subclasses:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cache
|
||||
Use `pytest -m "not cache"` to skip these tests
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_cache_inference(self):
|
||||
"""Test that model can run inference with FasterCache enabled."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
current_timestep = [1000]
|
||||
config = self._get_cache_config(current_timestep_callback=lambda: current_timestep[0])
|
||||
|
||||
model.enable_cache(config)
|
||||
|
||||
# First pass with timestep outside skip range - computes and populates cache
|
||||
current_timestep[0] = 1000
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# Move timestep inside skip range so subsequent passes use cache
|
||||
current_timestep[0] = 500
|
||||
|
||||
# Create modified inputs for second pass
|
||||
inputs_dict_step2 = inputs_dict.copy()
|
||||
if self.cache_input_key in inputs_dict_step2:
|
||||
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
|
||||
inputs_dict_step2[self.cache_input_key]
|
||||
)
|
||||
|
||||
# Second pass uses cached attention with different inputs
|
||||
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
assert output_with_cache is not None, "Model output should not be None with cache enabled."
|
||||
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
|
||||
|
||||
# Run same inputs without cache to compare
|
||||
model.disable_cache()
|
||||
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
|
||||
|
||||
# Cached output should be different from non-cached output (due to approximation)
|
||||
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
|
||||
"Cached output should be different from non-cached output due to cache approximation."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def _test_reset_stateful_cache(self):
|
||||
"""Test that _reset_stateful_cache resets the FasterCache state."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
|
||||
config = self._get_cache_config()
|
||||
model.enable_cache(config)
|
||||
|
||||
_ = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model._reset_stateful_cache()
|
||||
|
||||
model.disable_cache()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_enable_disable_state(self):
|
||||
self._test_cache_enable_disable_state()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_double_enable_raises_error(self):
|
||||
self._test_cache_double_enable_raises_error()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_hooks_registered(self):
|
||||
self._test_cache_hooks_registered()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_inference(self):
|
||||
self._test_cache_inference()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_context_manager(self):
|
||||
self._test_cache_context_manager()
|
||||
|
||||
@require_cache_mixin
|
||||
def test_faster_cache_reset_stateful_cache(self):
|
||||
self._test_reset_stateful_cache()
|
||||
@@ -1,666 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
|
||||
|
||||
from diffusers.utils import SAFE_WEIGHTS_INDEX_NAME, _add_variant, logging
|
||||
from diffusers.utils.testing_utils import require_accelerator, require_torch_multi_accelerator
|
||||
|
||||
from ...testing_utils import assert_tensors_close, torch_device
|
||||
|
||||
|
||||
def named_persistent_module_tensors(
|
||||
module: nn.Module,
|
||||
recurse: bool = False,
|
||||
):
|
||||
"""
|
||||
A helper function that gathers all the tensors (parameters + persistent buffers) of a given module.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module we want the tensors on.
|
||||
recurse (`bool`, *optional`, defaults to `False`):
|
||||
Whether or not to go look in every submodule or just return the direct parameters and buffers.
|
||||
"""
|
||||
yield from module.named_parameters(recurse=recurse)
|
||||
|
||||
for named_buffer in module.named_buffers(recurse=recurse):
|
||||
name, _ = named_buffer
|
||||
# Get parent by splitting on dots and traversing the model
|
||||
parent = module
|
||||
if "." in name:
|
||||
parent_name = name.rsplit(".", 1)[0]
|
||||
for part in parent_name.split("."):
|
||||
parent = getattr(parent, part)
|
||||
name = name.split(".")[-1]
|
||||
if name not in parent._non_persistent_buffers_set:
|
||||
yield named_buffer
|
||||
|
||||
|
||||
def compute_module_persistent_sizes(
|
||||
model: nn.Module,
|
||||
dtype: str | torch.device | None = None,
|
||||
special_dtypes: dict[str, str | torch.device] | None = None,
|
||||
):
|
||||
"""
|
||||
Compute the size of each submodule of a given model (parameters + persistent buffers).
|
||||
"""
|
||||
if dtype is not None:
|
||||
dtype = _get_proper_dtype(dtype)
|
||||
dtype_size = dtype_byte_size(dtype)
|
||||
if special_dtypes is not None:
|
||||
special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
|
||||
special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
|
||||
module_sizes = defaultdict(int)
|
||||
|
||||
module_list = []
|
||||
|
||||
module_list = named_persistent_module_tensors(model, recurse=True)
|
||||
|
||||
for name, tensor in module_list:
|
||||
if special_dtypes is not None and name in special_dtypes:
|
||||
size = tensor.numel() * special_dtypes_size[name]
|
||||
elif dtype is None:
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
|
||||
# According to the code in set_module_tensor_to_device, these types won't be converted
|
||||
# so use their original size here
|
||||
size = tensor.numel() * dtype_byte_size(tensor.dtype)
|
||||
else:
|
||||
size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
|
||||
name_parts = name.split(".")
|
||||
for idx in range(len(name_parts) + 1):
|
||||
module_sizes[".".join(name_parts[:idx])] += size
|
||||
|
||||
return module_sizes
|
||||
|
||||
|
||||
def calculate_expected_num_shards(index_map_path):
|
||||
"""
|
||||
Calculate expected number of shards from index file.
|
||||
|
||||
Args:
|
||||
index_map_path: Path to the sharded checkpoint index file
|
||||
|
||||
Returns:
|
||||
int: Expected number of shards
|
||||
"""
|
||||
with open(index_map_path) as f:
|
||||
weight_map_dict = json.load(f)["weight_map"]
|
||||
first_key = list(weight_map_dict.keys())[0]
|
||||
weight_loc = weight_map_dict[first_key] # e.g., diffusion_pytorch_model-00001-of-00002.safetensors
|
||||
expected_num_shards = int(weight_loc.split("-")[-1].split(".")[0])
|
||||
return expected_num_shards
|
||||
|
||||
|
||||
def check_device_map_is_respected(model, device_map):
|
||||
for param_name, param in model.named_parameters():
|
||||
# Find device in device_map
|
||||
while len(param_name) > 0 and param_name not in device_map:
|
||||
param_name = ".".join(param_name.split(".")[:-1])
|
||||
if param_name not in device_map:
|
||||
raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
|
||||
|
||||
param_device = device_map[param_name]
|
||||
if param_device in ["cpu", "disk"]:
|
||||
assert param.device == torch.device("meta"), f"Expected device 'meta' for {param_name}, got {param.device}"
|
||||
else:
|
||||
assert param.device == torch.device(param_device), (
|
||||
f"Expected device {param_device} for {param_name}, got {param.device}"
|
||||
)
|
||||
|
||||
|
||||
def cast_inputs_to_dtype(inputs, current_dtype, target_dtype):
|
||||
if torch.is_tensor(inputs):
|
||||
return inputs.to(target_dtype) if inputs.dtype == current_dtype else inputs
|
||||
if isinstance(inputs, dict):
|
||||
return {k: cast_inputs_to_dtype(v, current_dtype, target_dtype) for k, v in inputs.items()}
|
||||
if isinstance(inputs, list):
|
||||
return [cast_inputs_to_dtype(v, current_dtype, target_dtype) for v in inputs]
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
class BaseModelTesterConfig:
|
||||
"""
|
||||
Base class defining the configuration interface for model testing.
|
||||
|
||||
This class defines the contract that all model test classes must implement.
|
||||
It provides a consistent interface for accessing model configuration, initialization
|
||||
parameters, and test inputs across all testing mixins.
|
||||
|
||||
Required properties (must be implemented by subclasses):
|
||||
- model_class: The model class to test
|
||||
|
||||
Optional properties (can be overridden, have sensible defaults):
|
||||
- pretrained_model_name_or_path: Hub repository ID for pretrained model (default: None)
|
||||
- pretrained_model_kwargs: Additional kwargs for from_pretrained (default: {})
|
||||
- output_shape: Expected output shape for output validation tests (default: None)
|
||||
- model_split_percents: Percentages for model parallelism tests (default: [0.5, 0.7])
|
||||
|
||||
Required methods (must be implemented by subclasses):
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Example usage:
|
||||
class MyModelTestConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return MyModel
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "org/my-model"
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 3, 32, 32)
|
||||
|
||||
def get_init_dict(self):
|
||||
return {"in_channels": 3, "out_channels": 3}
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
return {"sample": torch.randn(1, 3, 32, 32, device=torch_device)}
|
||||
|
||||
class TestMyModel(MyModelTestConfig, ModelTesterMixin, QuantizationTesterMixin):
|
||||
pass
|
||||
"""
|
||||
|
||||
# ==================== Required Properties ====================
|
||||
|
||||
@property
|
||||
def model_class(self) -> Type[nn.Module]:
|
||||
"""The model class to test. Must be implemented by subclasses."""
|
||||
raise NotImplementedError("Subclasses must implement the `model_class` property.")
|
||||
|
||||
# ==================== Optional Properties ====================
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self) -> Optional[str]:
|
||||
"""Hub repository ID for the pretrained model (used for quantization and hub tests)."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def pretrained_model_kwargs(self) -> Dict[str, Any]:
|
||||
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def output_shape(self) -> Optional[tuple]:
|
||||
"""Expected output shape for output validation tests."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
"""Percentages for model parallelism tests."""
|
||||
return [0.9]
|
||||
|
||||
# ==================== Required Methods ====================
|
||||
|
||||
def get_init_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns dict of arguments to initialize the model.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Initialization arguments for the model constructor.
|
||||
|
||||
Example:
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"sample_size": 32,
|
||||
}
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement `get_init_dict()`.")
|
||||
|
||||
def get_dummy_inputs(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns dict of inputs to pass to the model forward pass.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Input tensors/values for model.forward().
|
||||
|
||||
Example:
|
||||
return {
|
||||
"sample": torch.randn(1, 3, 32, 32, device=torch_device),
|
||||
"timestep": torch.tensor([1], device=torch_device),
|
||||
}
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement `get_dummy_inputs()`.")
|
||||
|
||||
|
||||
class ModelTesterMixin:
|
||||
"""
|
||||
Base mixin class for model testing with common test methods.
|
||||
|
||||
This mixin expects the test class to also inherit from BaseModelTesterConfig
|
||||
(or implement its interface) which provides:
|
||||
- model_class: The model class to test
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Example:
|
||||
class MyModelTestConfig(BaseModelTesterConfig):
|
||||
model_class = MyModel
|
||||
def get_init_dict(self): ...
|
||||
def get_dummy_inputs(self): ...
|
||||
|
||||
class TestMyModel(MyModelTestConfig, ModelTesterMixin):
|
||||
pass
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def test_from_save_pretrained(self, tmp_path, atol=5e-5, rtol=5e-5):
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model.save_pretrained(tmp_path)
|
||||
new_model = self.model_class.from_pretrained(tmp_path)
|
||||
new_model.to(torch_device)
|
||||
|
||||
for param_name in model.state_dict().keys():
|
||||
param_1 = model.state_dict()[param_name]
|
||||
param_2 = new_model.state_dict()[param_name]
|
||||
assert param_1.shape == param_2.shape, (
|
||||
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
|
||||
)
|
||||
|
||||
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
|
||||
|
||||
@torch.no_grad()
|
||||
def test_from_save_pretrained_variant(self, tmp_path, atol=5e-5, rtol=0):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model.save_pretrained(tmp_path, variant="fp16")
|
||||
new_model = self.model_class.from_pretrained(tmp_path, variant="fp16")
|
||||
|
||||
with pytest.raises(OSError) as exc_info:
|
||||
self.model_class.from_pretrained(tmp_path)
|
||||
|
||||
assert "Error no file named diffusion_pytorch_model.bin found in directory" in str(exc_info.value)
|
||||
|
||||
new_model.to(torch_device)
|
||||
|
||||
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
|
||||
def test_from_save_pretrained_dtype(self, tmp_path, dtype):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
if torch_device == "mps" and dtype == torch.bfloat16:
|
||||
pytest.skip(reason=f"{dtype} is not supported on {torch_device}")
|
||||
|
||||
model.to(dtype)
|
||||
model.save_pretrained(tmp_path)
|
||||
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=True, torch_dtype=dtype)
|
||||
assert new_model.dtype == dtype
|
||||
if hasattr(self.model_class, "_keep_in_fp32_modules") and self.model_class._keep_in_fp32_modules is None:
|
||||
# When loading without accelerate dtype == torch.float32 if _keep_in_fp32_modules is not None
|
||||
new_model = self.model_class.from_pretrained(tmp_path, low_cpu_mem_usage=False, torch_dtype=dtype)
|
||||
assert new_model.dtype == dtype
|
||||
|
||||
@torch.no_grad()
|
||||
def test_determinism(self, atol=1e-5, rtol=0):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
|
||||
first_flat = first.flatten()
|
||||
second_flat = second.flatten()
|
||||
mask = ~(torch.isnan(first_flat) | torch.isnan(second_flat))
|
||||
first_filtered = first_flat[mask]
|
||||
second_filtered = second_flat[mask]
|
||||
|
||||
assert_tensors_close(
|
||||
first_filtered, second_filtered, atol=atol, rtol=rtol, msg="Model outputs are not deterministic"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_output(self, expected_output_shape=None):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert output is not None, "Model output is None"
|
||||
assert output[0].shape == expected_output_shape or self.output_shape, (
|
||||
f"Output shape does not match expected. Expected {expected_output_shape}, got {output.shape}"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_outputs_equivalence(self, atol=1e-5, rtol=0):
|
||||
def set_nan_tensor_to_zero(t):
|
||||
device = t.device
|
||||
if device.type == "mps":
|
||||
t = t.to("cpu")
|
||||
t[t != t] = 0
|
||||
return t.to(device)
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (list, tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif isinstance(tuple_object, dict):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
assert_tensors_close(
|
||||
set_nan_tensor_to_zero(tuple_object),
|
||||
set_nan_tensor_to_zero(dict_object),
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Tuple and dict output are not equal",
|
||||
)
|
||||
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
outputs_dict = model(**self.get_dummy_inputs())
|
||||
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
|
||||
|
||||
recursive_check(outputs_tuple, outputs_dict)
|
||||
|
||||
def test_getattr_is_correct(self, caplog):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
# save some things to test
|
||||
model.dummy_attribute = 5
|
||||
model.register_to_config(test_attribute=5)
|
||||
|
||||
logger_name = "diffusers.models.modeling_utils"
|
||||
with caplog.at_level(logging.WARNING, logger=logger_name):
|
||||
caplog.clear()
|
||||
assert hasattr(model, "dummy_attribute")
|
||||
assert getattr(model, "dummy_attribute") == 5
|
||||
assert model.dummy_attribute == 5
|
||||
|
||||
# no warning should be thrown
|
||||
assert caplog.text == ""
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger=logger_name):
|
||||
caplog.clear()
|
||||
assert hasattr(model, "save_pretrained")
|
||||
fn = model.save_pretrained
|
||||
fn_1 = getattr(model, "save_pretrained")
|
||||
|
||||
assert fn == fn_1
|
||||
|
||||
# no warning should be thrown
|
||||
assert caplog.text == ""
|
||||
|
||||
# warning should be thrown for config attributes accessed directly
|
||||
with pytest.warns(FutureWarning):
|
||||
assert model.test_attribute == 5
|
||||
|
||||
with pytest.warns(FutureWarning):
|
||||
assert getattr(model, "test_attribute") == 5
|
||||
|
||||
with pytest.raises(AttributeError) as error:
|
||||
model.does_not_exist
|
||||
|
||||
assert str(error.value) == f"'{type(model).__name__}' object has no attribute 'does_not_exist'"
|
||||
|
||||
@require_accelerator
|
||||
@pytest.mark.skipif(
|
||||
torch_device not in ["cuda", "xpu"],
|
||||
reason="float16 and bfloat16 can only be used with an accelerator",
|
||||
)
|
||||
def test_keep_in_fp32_modules(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
fp32_modules = model._keep_in_fp32_modules
|
||||
|
||||
if fp32_modules is None or len(fp32_modules) == 0:
|
||||
pytest.skip("Model does not have _keep_in_fp32_modules defined.")
|
||||
|
||||
# Test with float16
|
||||
model.to(torch_device)
|
||||
model.to(torch.float16)
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
|
||||
assert param.dtype == torch.float32, f"Parameter {name} should be float32 but got {param.dtype}"
|
||||
else:
|
||||
assert param.dtype == torch.float16, f"Parameter {name} should be float16 but got {param.dtype}"
|
||||
|
||||
@require_accelerator
|
||||
@pytest.mark.skipif(
|
||||
torch_device not in ["cuda", "xpu"],
|
||||
reason="float16 and bfloat16 can only be use for inference with an accelerator",
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
|
||||
@torch.no_grad()
|
||||
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
fp32_modules = model._keep_in_fp32_modules or []
|
||||
|
||||
model.to(dtype).save_pretrained(tmp_path)
|
||||
model_loaded = self.model_class.from_pretrained(tmp_path, torch_dtype=dtype).to(torch_device)
|
||||
|
||||
for name, param in model_loaded.named_parameters():
|
||||
if fp32_modules and any(
|
||||
module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules
|
||||
):
|
||||
assert param.data.dtype == torch.float32
|
||||
else:
|
||||
assert param.data.dtype == dtype
|
||||
|
||||
inputs = cast_inputs_to_dtype(self.get_dummy_inputs(), torch.float32, dtype)
|
||||
output = model(**inputs, return_dict=False)[0]
|
||||
output_loaded = model_loaded(**inputs, return_dict=False)[0]
|
||||
|
||||
self._check_dtype_inference_output(output, output_loaded, dtype)
|
||||
|
||||
def _check_dtype_inference_output(self, output, output_loaded, dtype, atol=1e-4, rtol=0):
|
||||
"""Check dtype inference output with configurable tolerance."""
|
||||
assert_tensors_close(
|
||||
output, output_loaded, atol=atol, rtol=rtol, msg=f"Loaded model output differs for {dtype}"
|
||||
)
|
||||
|
||||
@require_accelerator
|
||||
@torch.no_grad()
|
||||
def test_sharded_checkpoints(self, tmp_path, atol=1e-5, rtol=0):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
|
||||
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
|
||||
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
|
||||
assert actual_num_shards == expected_num_shards, (
|
||||
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
)
|
||||
|
||||
new_model = self.model_class.from_pretrained(tmp_path).eval()
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
|
||||
)
|
||||
|
||||
@require_accelerator
|
||||
@torch.no_grad()
|
||||
def test_sharded_checkpoints_with_variant(self, tmp_path, atol=1e-5, rtol=0):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
variant = "fp16"
|
||||
|
||||
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB", variant=variant)
|
||||
|
||||
index_filename = _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
|
||||
assert os.path.exists(os.path.join(tmp_path, index_filename)), (
|
||||
f"Variant index file {index_filename} should exist"
|
||||
)
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, index_filename))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
|
||||
assert actual_num_shards == expected_num_shards, (
|
||||
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
)
|
||||
|
||||
new_model = self.model_class.from_pretrained(tmp_path, variant=variant).eval()
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_sharded_checkpoints_with_parallel_loading(self, tmp_path, atol=1e-5, rtol=0):
|
||||
from diffusers.utils import constants
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device)
|
||||
|
||||
base_output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model_size = compute_module_persistent_sizes(model)[""]
|
||||
max_shard_size = int((model_size * 0.75) / (2**10)) # Convert to KB as these test models are small
|
||||
|
||||
# Save original values to restore after test
|
||||
original_parallel_loading = constants.HF_ENABLE_PARALLEL_LOADING
|
||||
original_parallel_workers = getattr(constants, "HF_PARALLEL_WORKERS", None)
|
||||
|
||||
try:
|
||||
model.cpu().save_pretrained(tmp_path, max_shard_size=f"{max_shard_size}KB")
|
||||
assert os.path.exists(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME)), "Index file should exist"
|
||||
|
||||
# Check if the right number of shards exists
|
||||
expected_num_shards = calculate_expected_num_shards(os.path.join(tmp_path, SAFE_WEIGHTS_INDEX_NAME))
|
||||
actual_num_shards = len([file for file in os.listdir(tmp_path) if file.endswith(".safetensors")])
|
||||
assert actual_num_shards == expected_num_shards, (
|
||||
f"Expected {expected_num_shards} shards, got {actual_num_shards}"
|
||||
)
|
||||
|
||||
# Load without parallel loading
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = False
|
||||
model_sequential = self.model_class.from_pretrained(tmp_path).eval()
|
||||
model_sequential = model_sequential.to(torch_device)
|
||||
|
||||
# Load with parallel loading
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = True
|
||||
constants.DEFAULT_HF_PARALLEL_LOADING_WORKERS = 2
|
||||
|
||||
torch.manual_seed(0)
|
||||
model_parallel = self.model_class.from_pretrained(tmp_path).eval()
|
||||
model_parallel = model_parallel.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
inputs_dict_parallel = self.get_dummy_inputs()
|
||||
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"
|
||||
)
|
||||
|
||||
finally:
|
||||
# Restore original values
|
||||
constants.HF_ENABLE_PARALLEL_LOADING = original_parallel_loading
|
||||
if original_parallel_workers is not None:
|
||||
constants.HF_PARALLEL_WORKERS = original_parallel_workers
|
||||
|
||||
@require_torch_multi_accelerator
|
||||
@torch.no_grad()
|
||||
def test_model_parallelism(self, tmp_path, atol=1e-5, rtol=0):
|
||||
if self.model_class._no_split_modules is None:
|
||||
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
|
||||
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
|
||||
|
||||
model.cpu().save_pretrained(tmp_path)
|
||||
|
||||
for max_size in max_gpu_sizes:
|
||||
max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
|
||||
new_model = self.model_class.from_pretrained(tmp_path, device_map="auto", max_memory=max_memory)
|
||||
# Making sure part of the model will be on GPU 0 and GPU 1
|
||||
assert set(new_model.hf_device_map.values()) == {0, 1}, "Model should be split across GPUs"
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match with model parallelism"
|
||||
)
|
||||
@@ -1,166 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_torch_compile,
|
||||
require_accelerator,
|
||||
require_torch_version_greater,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@is_torch_compile
|
||||
@require_accelerator
|
||||
@require_torch_version_greater("2.7.1")
|
||||
class TorchCompileTesterMixin:
|
||||
"""
|
||||
Mixin class for testing torch.compile functionality on models.
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Optional properties:
|
||||
- different_shapes_for_compilation: List of (height, width) tuples for dynamic shape testing (default: None)
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: compile
|
||||
Use `pytest -m "not compile"` to skip these tests
|
||||
"""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None:
|
||||
"""Optional list of (height, width) tuples for dynamic shape testing."""
|
||||
return None
|
||||
|
||||
def setup_method(self):
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model = torch.compile(model, fullgraph=True)
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=True),
|
||||
):
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_torch_compile_repeated_blocks(self):
|
||||
if self.model_class._repeated_blocks is None:
|
||||
pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model.compile_repeated_blocks(fullgraph=True)
|
||||
|
||||
recompile_limit = 1
|
||||
if self.model_class.__name__ == "UNet2DConditionModel":
|
||||
recompile_limit = 2
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(recompile_limit=recompile_limit),
|
||||
):
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_compile_with_group_offloading(self):
|
||||
if not self.model_class._supports_group_offloading:
|
||||
pytest.skip("Model does not support group offloading.")
|
||||
|
||||
torch._dynamo.config.cache_size_limit = 10000
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
|
||||
group_offload_kwargs = {
|
||||
"onload_device": torch_device,
|
||||
"offload_device": "cpu",
|
||||
"offload_type": "block_level",
|
||||
"num_blocks_per_group": 1,
|
||||
"use_stream": True,
|
||||
"non_blocking": True,
|
||||
}
|
||||
model.enable_group_offload(**group_offload_kwargs)
|
||||
model.compile()
|
||||
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_compile_on_different_shapes(self):
|
||||
if self.different_shapes_for_compilation is None:
|
||||
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
|
||||
torch.fx.experimental._config.use_duck_shape = False
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model = torch.compile(model, fullgraph=True, dynamic=True)
|
||||
|
||||
for height, width in self.different_shapes_for_compilation:
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
inputs_dict = self.get_dummy_inputs(height=height, width=width)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
@torch.no_grad()
|
||||
def test_compile_works_with_aot(self, tmp_path):
|
||||
from torch._inductor.package import load_package
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
exported_model = torch.export.export(model, args=(), kwargs=inputs_dict)
|
||||
|
||||
package_path = os.path.join(str(tmp_path), f"{self.model_class.__name__}.pt2")
|
||||
_ = torch._inductor.aoti_compile_and_package(exported_model, package_path=package_path)
|
||||
assert os.path.exists(package_path), f"Package file not created at {package_path}"
|
||||
loaded_binary = load_package(package_path, run_single_threaded=True)
|
||||
|
||||
model.forward = loaded_binary
|
||||
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
@@ -1,158 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from ...testing_utils import backend_empty_cache, is_ip_adapter, torch_device
|
||||
|
||||
|
||||
def check_if_ip_adapter_correctly_set(model, processor_cls) -> bool:
|
||||
"""
|
||||
Check if IP Adapter processors are correctly set in the model.
|
||||
|
||||
Args:
|
||||
model: The model to check
|
||||
|
||||
Returns:
|
||||
bool: True if IP Adapter is correctly set, False otherwise
|
||||
"""
|
||||
for module in model.attn_processors.values():
|
||||
if isinstance(module, processor_cls):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@is_ip_adapter
|
||||
class IPAdapterTesterMixin:
|
||||
"""
|
||||
Mixin class for testing IP Adapter functionality on models.
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Required properties (must be implemented by subclasses):
|
||||
- ip_adapter_processor_cls: The IP Adapter processor class to use
|
||||
|
||||
Required methods (must be implemented by subclasses):
|
||||
- create_ip_adapter_state_dict(): Creates IP Adapter state dict for testing
|
||||
- modify_inputs_for_ip_adapter(): Modifies inputs to include IP Adapter data
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: ip_adapter
|
||||
Use `pytest -m "not ip_adapter"` to skip these tests
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@property
|
||||
def ip_adapter_processor_cls(self):
|
||||
"""IP Adapter processor class to use for testing. Must be implemented by subclasses."""
|
||||
raise NotImplementedError("Subclasses must implement the `ip_adapter_processor_cls` property.")
|
||||
|
||||
def create_ip_adapter_state_dict(self, model):
|
||||
raise NotImplementedError("child class must implement method to create IPAdapter State Dict")
|
||||
|
||||
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
|
||||
raise NotImplementedError("child class must implement method to create IPAdapter model inputs")
|
||||
|
||||
@torch.no_grad()
|
||||
def test_load_ip_adapter(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_no_adapter = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
|
||||
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
|
||||
"IP Adapter processors not set correctly"
|
||||
)
|
||||
|
||||
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
|
||||
outputs_with_adapter = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(output_no_adapter, outputs_with_adapter, atol=1e-4, rtol=1e-4), (
|
||||
"Output should differ with IP Adapter enabled"
|
||||
)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Setting IP Adapter scale is not defined at the model level. Enable this test after refactoring"
|
||||
)
|
||||
def test_ip_adapter_scale(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
|
||||
inputs_dict_with_adapter = self.modify_inputs_for_ip_adapter(model, inputs_dict.copy())
|
||||
|
||||
# Test scale = 0.0 (no effect)
|
||||
model.set_ip_adapter_scale(0.0)
|
||||
torch.manual_seed(0)
|
||||
output_scale_zero = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
# Test scale = 1.0 (full effect)
|
||||
model.set_ip_adapter_scale(1.0)
|
||||
torch.manual_seed(0)
|
||||
output_scale_one = model(**inputs_dict_with_adapter, return_dict=False)[0]
|
||||
|
||||
# Outputs should differ with different scales
|
||||
assert not torch.allclose(output_scale_zero, output_scale_one, atol=1e-4, rtol=1e-4), (
|
||||
"Output should differ with different IP Adapter scales"
|
||||
)
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Unloading IP Adapter is not defined at the model level. Enable this test after refactoring"
|
||||
)
|
||||
def test_unload_ip_adapter(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Save original processors
|
||||
original_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
|
||||
|
||||
# Create and load IP adapter
|
||||
ip_adapter_state_dict = self.create_ip_adapter_state_dict(model)
|
||||
model._load_ip_adapter_weights([ip_adapter_state_dict])
|
||||
|
||||
assert check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), "IP Adapter should be set"
|
||||
|
||||
# Unload IP adapter
|
||||
model.unload_ip_adapter()
|
||||
|
||||
assert not check_if_ip_adapter_correctly_set(model, self.ip_adapter_processor_cls), (
|
||||
"IP Adapter should be unloaded"
|
||||
)
|
||||
|
||||
# Verify processors are restored
|
||||
current_processors = {k: type(v).__name__ for k, v in model.attn_processors.items()}
|
||||
assert original_processors == current_processors, "Processors should be restored after unload"
|
||||
@@ -1,555 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import pytest
|
||||
import safetensors.torch
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from diffusers.utils.import_utils import is_peft_available
|
||||
from diffusers.utils.testing_utils import check_if_dicts_are_equal
|
||||
|
||||
from ...testing_utils import (
|
||||
assert_tensors_close,
|
||||
backend_empty_cache,
|
||||
is_lora,
|
||||
is_torch_compile,
|
||||
require_peft_backend,
|
||||
require_peft_version_greater,
|
||||
require_torch_accelerator,
|
||||
require_torch_version_greater,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from diffusers.loaders.peft import PeftAdapterMixin
|
||||
|
||||
|
||||
def check_if_lora_correctly_set(model) -> bool:
|
||||
"""
|
||||
Check if LoRA layers are correctly set in the model.
|
||||
|
||||
Args:
|
||||
model: The model to check
|
||||
|
||||
Returns:
|
||||
bool: True if LoRA is correctly set, False otherwise
|
||||
"""
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@is_lora
|
||||
@require_peft_backend
|
||||
class LoraTesterMixin:
|
||||
"""
|
||||
Mixin class for testing LoRA/PEFT functionality on models.
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: lora
|
||||
Use `pytest -m "not lora"` to skip these tests
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
if not issubclass(self.model_class, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
|
||||
|
||||
@torch.no_grad()
|
||||
def test_save_load_lora_adapter(self, tmp_path, rank=4, lora_alpha=4, use_dora=False, atol=1e-4, rtol=1e-4):
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_no_lora = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
torch.manual_seed(0)
|
||||
outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(output_no_lora, outputs_with_lora, atol=atol, rtol=rtol), (
|
||||
"Output should differ with LoRA enabled"
|
||||
)
|
||||
|
||||
model.save_lora_adapter(tmp_path)
|
||||
assert os.path.isfile(os.path.join(tmp_path, "pytorch_lora_weights.safetensors")), (
|
||||
"LoRA weights file not created"
|
||||
)
|
||||
|
||||
state_dict_loaded = safetensors.torch.load_file(os.path.join(tmp_path, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
|
||||
state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
|
||||
|
||||
for k in state_dict_loaded:
|
||||
loaded_v = state_dict_loaded[k]
|
||||
retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
|
||||
assert_tensors_close(loaded_v, retrieved_v, atol=atol, rtol=rtol, msg=f"Mismatch in LoRA weight {k}")
|
||||
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly after reload"
|
||||
|
||||
torch.manual_seed(0)
|
||||
outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=atol, rtol=rtol), (
|
||||
"Output should differ with LoRA enabled"
|
||||
)
|
||||
assert_tensors_close(
|
||||
outputs_with_lora,
|
||||
outputs_with_lora_2,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Outputs should match before and after save/load",
|
||||
)
|
||||
|
||||
def test_lora_wrong_adapter_name_raises_error(self, tmp_path):
|
||||
from peft import LoraConfig
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
wrong_name = "foo"
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
model.save_lora_adapter(tmp_path, adapter_name=wrong_name)
|
||||
|
||||
assert f"Adapter name {wrong_name} not found in the model." in str(exc_info.value)
|
||||
|
||||
def test_lora_adapter_metadata_is_loaded_correctly(self, tmp_path, rank=4, lora_alpha=4, use_dora=False):
|
||||
from peft import LoraConfig
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
metadata = model.peft_config["default"].to_dict()
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
model.save_lora_adapter(tmp_path)
|
||||
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
|
||||
assert os.path.isfile(model_file), "LoRA weights file not created"
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
|
||||
parsed_metadata = model.peft_config["default_0"].to_dict()
|
||||
check_if_dicts_are_equal(metadata, parsed_metadata)
|
||||
|
||||
def test_lora_adapter_wrong_metadata_raises_error(self, tmp_path):
|
||||
from peft import LoraConfig
|
||||
|
||||
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
denoiser_lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
model.add_adapter(denoiser_lora_config)
|
||||
assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
|
||||
|
||||
model.save_lora_adapter(tmp_path)
|
||||
model_file = os.path.join(tmp_path, "pytorch_lora_weights.safetensors")
|
||||
assert os.path.isfile(model_file), "LoRA weights file not created"
|
||||
|
||||
# Perturb the metadata in the state dict
|
||||
loaded_state_dict = safetensors.torch.load_file(model_file)
|
||||
metadata = {"format": "pt"}
|
||||
lora_adapter_metadata = denoiser_lora_config.to_dict()
|
||||
lora_adapter_metadata.update({"foo": 1, "bar": 2})
|
||||
for key, value in lora_adapter_metadata.items():
|
||||
if isinstance(value, set):
|
||||
lora_adapter_metadata[key] = list(value)
|
||||
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
|
||||
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
|
||||
|
||||
model.unload_lora()
|
||||
assert not check_if_lora_correctly_set(model), "LoRA should be unloaded"
|
||||
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
model.load_lora_adapter(tmp_path, prefix=None, use_safetensors=True)
|
||||
assert "`LoraConfig` class could not be instantiated" in str(exc_info.value)
|
||||
|
||||
|
||||
@is_lora
|
||||
@is_torch_compile
|
||||
@require_peft_backend
|
||||
@require_peft_version_greater("0.14.0")
|
||||
@require_torch_version_greater("2.7.1")
|
||||
@require_torch_accelerator
|
||||
class LoraHotSwappingForModelTesterMixin:
|
||||
"""
|
||||
Mixin class for testing LoRA hot swapping functionality on models.
|
||||
|
||||
Test that hotswapping does not result in recompilation on the model directly.
|
||||
We're not extensively testing the hotswapping functionality since it is implemented in PEFT
|
||||
and is extensively tested there. The goal of this test is specifically to ensure that
|
||||
hotswapping with diffusers does not require recompilation.
|
||||
|
||||
See https://github.com/huggingface/peft/blob/eaab05e18d51fb4cce20a73c9acd82a00c013b83/tests/test_gpu_examples.py#L4252
|
||||
for the analogous PEFT test.
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Optional properties:
|
||||
- different_shapes_for_compilation: List of (height, width) tuples for dynamic compilation tests (default: None)
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest marks: lora, torch_compile
|
||||
Use `pytest -m "not lora"` or `pytest -m "not torch_compile"` to skip these tests
|
||||
"""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self) -> list[tuple[int, int]] | None:
|
||||
"""Optional list of (height, width) tuples for dynamic compilation tests."""
|
||||
return None
|
||||
|
||||
def setup_method(self):
|
||||
if not issubclass(self.model_class, PeftAdapterMixin):
|
||||
pytest.skip(f"PEFT is not supported for this model ({self.model_class.__name__}).")
|
||||
|
||||
def teardown_method(self):
|
||||
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
|
||||
# there will be recompilation errors, as torch caches the model when run in the same process.
|
||||
torch.compiler.reset()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def _get_lora_config(self, lora_rank, lora_alpha, target_modules):
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
target_modules=target_modules,
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
return lora_config
|
||||
|
||||
def _get_linear_module_name_other_than_attn(self, model):
|
||||
linear_names = [
|
||||
name for name, module in model.named_modules() if isinstance(module, nn.Linear) and "to_" not in name
|
||||
]
|
||||
return linear_names[0]
|
||||
|
||||
def _check_model_hotswap(
|
||||
self, tmp_path, do_compile, rank0, rank1, target_modules0, target_modules1=None, atol=5e-3, rtol=5e-3
|
||||
):
|
||||
"""
|
||||
Check that hotswapping works on a model.
|
||||
|
||||
Steps:
|
||||
- create 2 LoRA adapters and save them
|
||||
- load the first adapter
|
||||
- hotswap the second adapter
|
||||
- check that the outputs are correct
|
||||
- optionally compile the model
|
||||
- optionally check if recompilations happen on different shapes
|
||||
|
||||
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
|
||||
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
|
||||
fine.
|
||||
"""
|
||||
different_shapes = self.different_shapes_for_compilation
|
||||
# create 2 adapters with different ranks and alphas
|
||||
torch.manual_seed(0)
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
alpha0, alpha1 = rank0, rank1
|
||||
max_rank = max([rank0, rank1])
|
||||
if target_modules1 is None:
|
||||
target_modules1 = target_modules0[:]
|
||||
lora_config0 = self._get_lora_config(rank0, alpha0, target_modules0)
|
||||
lora_config1 = self._get_lora_config(rank1, alpha1, target_modules1)
|
||||
|
||||
model.add_adapter(lora_config0, adapter_name="adapter0")
|
||||
with torch.inference_mode():
|
||||
torch.manual_seed(0)
|
||||
output0_before = model(**inputs_dict)["sample"]
|
||||
|
||||
model.add_adapter(lora_config1, adapter_name="adapter1")
|
||||
model.set_adapter("adapter1")
|
||||
with torch.inference_mode():
|
||||
torch.manual_seed(0)
|
||||
output1_before = model(**inputs_dict)["sample"]
|
||||
|
||||
# sanity checks:
|
||||
assert not torch.allclose(output0_before, output1_before, atol=atol, rtol=rtol)
|
||||
assert not (output0_before == 0).all()
|
||||
assert not (output1_before == 0).all()
|
||||
|
||||
# save the adapter checkpoints
|
||||
model.save_lora_adapter(os.path.join(tmp_path, "0"), safe_serialization=True, adapter_name="adapter0")
|
||||
model.save_lora_adapter(os.path.join(tmp_path, "1"), safe_serialization=True, adapter_name="adapter1")
|
||||
del model
|
||||
|
||||
# load the first adapter
|
||||
torch.manual_seed(0)
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
if do_compile or (rank0 != rank1):
|
||||
# no need to prepare if the model is not compiled or if the ranks are identical
|
||||
model.enable_lora_hotswap(target_rank=max_rank)
|
||||
|
||||
file_name0 = os.path.join(os.path.join(tmp_path, "0"), "pytorch_lora_weights.safetensors")
|
||||
file_name1 = os.path.join(os.path.join(tmp_path, "1"), "pytorch_lora_weights.safetensors")
|
||||
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
|
||||
|
||||
if do_compile:
|
||||
model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
|
||||
|
||||
with torch.inference_mode():
|
||||
# additionally check if dynamic compilation works.
|
||||
if different_shapes is not None:
|
||||
for height, width in different_shapes:
|
||||
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
|
||||
_ = model(**new_inputs_dict)
|
||||
else:
|
||||
output0_after = model(**inputs_dict)["sample"]
|
||||
assert_tensors_close(
|
||||
output0_before, output0_after, atol=atol, rtol=rtol, msg="Output mismatch after loading adapter0"
|
||||
)
|
||||
|
||||
# hotswap the 2nd adapter
|
||||
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
|
||||
|
||||
# we need to call forward to potentially trigger recompilation
|
||||
with torch.inference_mode():
|
||||
if different_shapes is not None:
|
||||
for height, width in different_shapes:
|
||||
new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
|
||||
_ = model(**new_inputs_dict)
|
||||
else:
|
||||
output1_after = model(**inputs_dict)["sample"]
|
||||
assert_tensors_close(
|
||||
output1_before,
|
||||
output1_after,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output mismatch after hotswapping to adapter1",
|
||||
)
|
||||
|
||||
# check error when not passing valid adapter name
|
||||
name = "does-not-exist"
|
||||
msg = f"Trying to hotswap LoRA adapter '{name}' but there is no existing adapter by that name"
|
||||
with pytest.raises(ValueError, match=re.escape(msg)):
|
||||
model.load_lora_adapter(file_name1, adapter_name=name, hotswap=True, prefix=None)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
def test_hotswapping_model(self, tmp_path, rank0, rank1):
|
||||
self._check_model_hotswap(
|
||||
tmp_path, do_compile=False, rank0=rank0, rank1=rank1, target_modules0=["to_q", "to_k", "to_v", "to_out.0"]
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
def test_hotswapping_compiled_model_linear(self, tmp_path, rank0, rank1):
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||
self._check_model_hotswap(
|
||||
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
def test_hotswapping_compiled_model_conv2d(self, tmp_path, rank0, rank1):
|
||||
if "unet" not in self.model_class.__name__.lower():
|
||||
pytest.skip("Test only applies to UNet.")
|
||||
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["conv", "conv1", "conv2"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||
self._check_model_hotswap(
|
||||
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, tmp_path, rank0, rank1):
|
||||
if "unet" not in self.model_class.__name__.lower():
|
||||
pytest.skip("Test only applies to UNet.")
|
||||
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
target_modules = ["to_q", "conv"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True), torch._inductor.utils.fresh_inductor_cache():
|
||||
self._check_model_hotswap(
|
||||
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
def test_hotswapping_compiled_model_both_linear_and_other(self, tmp_path, rank0, rank1):
|
||||
# In `test_hotswapping_compiled_model_both_linear_and_conv2d()`, we check if we can do hotswapping
|
||||
# with `torch.compile()` for models that have both linear and conv layers. In this test, we check
|
||||
# if we can target a linear layer from the transformer blocks and another linear layer from non-attention
|
||||
# block.
|
||||
target_modules = ["to_q"]
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
target_modules.append(self._get_linear_module_name_other_than_attn(model))
|
||||
del model
|
||||
|
||||
# It's important to add this context to raise an error on recompilation
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self._check_model_hotswap(
|
||||
tmp_path, do_compile=True, rank0=rank0, rank1=rank1, target_modules0=target_modules
|
||||
)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_raises(self):
|
||||
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
|
||||
msg = re.escape("Call `enable_lora_hotswap` before loading the first adapter.")
|
||||
with pytest.raises(RuntimeError, match=msg):
|
||||
model.enable_lora_hotswap(target_rank=32)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
|
||||
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||
import logging
|
||||
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
msg = (
|
||||
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
|
||||
)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
assert any(msg in record.message for record in caplog.records)
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
|
||||
# check possibility to ignore the error/warning
|
||||
import logging
|
||||
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
|
||||
assert len(caplog.records) == 0
|
||||
|
||||
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
|
||||
# check that wrong argument value raises an error
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
msg = re.escape("check_compiles should be one of 'error', 'warn', or 'ignore', got 'wrong-argument' instead.")
|
||||
with pytest.raises(ValueError, match=msg):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
|
||||
|
||||
def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog):
|
||||
# check the error and log
|
||||
import logging
|
||||
|
||||
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
|
||||
target_modules0 = ["to_q"]
|
||||
target_modules1 = ["to_q", "to_k"]
|
||||
with pytest.raises(RuntimeError): # peft raises RuntimeError
|
||||
with caplog.at_level(logging.ERROR):
|
||||
self._check_model_hotswap(
|
||||
tmp_path,
|
||||
do_compile=True,
|
||||
rank0=8,
|
||||
rank1=8,
|
||||
target_modules0=target_modules0,
|
||||
target_modules1=target_modules1,
|
||||
)
|
||||
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
@require_torch_version_greater("2.7.1")
|
||||
def test_hotswapping_compile_on_different_shapes(self, tmp_path, rank0, rank1):
|
||||
different_shapes_for_compilation = self.different_shapes_for_compilation
|
||||
if different_shapes_for_compilation is None:
|
||||
pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
|
||||
# Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic
|
||||
# variable to represent input sizes that are the same. For more details,
|
||||
# check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
|
||||
torch.fx.experimental._config.use_duck_shape = False
|
||||
|
||||
target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
|
||||
with torch._dynamo.config.patch(error_on_recompile=True):
|
||||
self._check_model_hotswap(
|
||||
tmp_path,
|
||||
do_compile=True,
|
||||
rank0=rank0,
|
||||
rank1=rank1,
|
||||
target_modules0=target_modules,
|
||||
)
|
||||
@@ -1,498 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import glob
|
||||
import inspect
|
||||
from functools import wraps
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from accelerate.utils.modeling import compute_module_sizes
|
||||
|
||||
from diffusers.utils.testing_utils import _check_safetensors_serialization
|
||||
from diffusers.utils.torch_utils import get_torch_cuda_device_capability
|
||||
|
||||
from ...testing_utils import (
|
||||
assert_tensors_close,
|
||||
backend_empty_cache,
|
||||
backend_max_memory_allocated,
|
||||
backend_reset_peak_memory_stats,
|
||||
backend_synchronize,
|
||||
is_cpu_offload,
|
||||
is_group_offload,
|
||||
is_memory,
|
||||
require_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from .common import cast_inputs_to_dtype, check_device_map_is_respected
|
||||
|
||||
|
||||
def require_offload_support(func):
|
||||
"""
|
||||
Decorator to skip tests if model doesn't support offloading (requires _no_split_modules).
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self.model_class._no_split_modules is None:
|
||||
pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_group_offload_support(func):
|
||||
"""
|
||||
Decorator to skip tests if model doesn't support group offloading.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not self.model_class._supports_group_offloading:
|
||||
pytest.skip("Model does not support group offloading.")
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@is_cpu_offload
|
||||
class CPUOffloadTesterMixin:
|
||||
"""
|
||||
Mixin class for testing CPU offloading functionality.
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Optional properties:
|
||||
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: cpu_offload
|
||||
Use `pytest -m "not cpu_offload"` to skip these tests
|
||||
"""
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list[float]:
|
||||
"""List of percentages for splitting model across devices during offloading tests."""
|
||||
return [0.5, 0.7]
|
||||
|
||||
@require_offload_support
|
||||
@torch.no_grad()
|
||||
def test_cpu_offload(self, tmp_path, atol=1e-5, rtol=0):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
# We test several splits of sizes to make sure it works
|
||||
max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents]
|
||||
model.cpu().save_pretrained(str(tmp_path))
|
||||
|
||||
for max_size in max_gpu_sizes:
|
||||
max_memory = {0: max_size, "cpu": model_size * 2}
|
||||
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
|
||||
# Making sure part of the model will actually end up offloaded
|
||||
assert set(new_model.hf_device_map.values()) == {0, "cpu"}, "Model should be split between GPU and CPU"
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert_tensors_close(
|
||||
base_output[0], new_output[0], atol=atol, rtol=rtol, msg="Output should match with CPU offloading"
|
||||
)
|
||||
|
||||
@require_offload_support
|
||||
@torch.no_grad()
|
||||
def test_disk_offload_without_safetensors(self, tmp_path, atol=1e-5, rtol=0):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
# Force disk offload by setting very small CPU memory
|
||||
max_memory = {0: max_size, "cpu": int(0.1 * max_size)}
|
||||
|
||||
model.cpu().save_pretrained(str(tmp_path), safe_serialization=False)
|
||||
# This errors out because it's missing an offload folder
|
||||
with pytest.raises(ValueError):
|
||||
new_model = self.model_class.from_pretrained(str(tmp_path), device_map="auto", max_memory=max_memory)
|
||||
|
||||
new_model = self.model_class.from_pretrained(
|
||||
str(tmp_path), device_map="auto", max_memory=max_memory, offload_folder=str(tmp_path)
|
||||
)
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert_tensors_close(
|
||||
base_output[0], new_output[0], atol=atol, rtol=rtol, msg="Output should match with disk offloading"
|
||||
)
|
||||
|
||||
@require_offload_support
|
||||
@torch.no_grad()
|
||||
def test_disk_offload_with_safetensors(self, tmp_path, atol=1e-5, rtol=0):
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**config).eval()
|
||||
|
||||
model = model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
base_output = model(**inputs_dict)
|
||||
|
||||
model_size = compute_module_sizes(model)[""]
|
||||
model.cpu().save_pretrained(str(tmp_path))
|
||||
|
||||
max_size = int(self.model_split_percents[0] * model_size)
|
||||
max_memory = {0: max_size, "cpu": max_size}
|
||||
new_model = self.model_class.from_pretrained(
|
||||
str(tmp_path), device_map="auto", offload_folder=str(tmp_path), max_memory=max_memory
|
||||
)
|
||||
|
||||
check_device_map_is_respected(new_model, new_model.hf_device_map)
|
||||
torch.manual_seed(0)
|
||||
new_output = new_model(**inputs_dict)
|
||||
|
||||
assert_tensors_close(
|
||||
base_output[0],
|
||||
new_output[0],
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output should match with disk offloading (safetensors)",
|
||||
)
|
||||
|
||||
|
||||
@is_group_offload
|
||||
class GroupOffloadTesterMixin:
|
||||
"""
|
||||
Mixin class for testing group offloading functionality.
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: group_offload
|
||||
Use `pytest -m "not group_offload"` to skip these tests
|
||||
"""
|
||||
|
||||
@require_group_offload_support
|
||||
@pytest.mark.parametrize("record_stream", [False, True])
|
||||
def test_group_offloading(self, record_stream, atol=1e-5, rtol=0):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
torch.manual_seed(0)
|
||||
|
||||
@torch.no_grad()
|
||||
def run_forward(model):
|
||||
assert all(
|
||||
module._diffusers_hook.get_hook("group_offloading") is not None
|
||||
for module in model.modules()
|
||||
if hasattr(module, "_diffusers_hook")
|
||||
), "Group offloading hook should be set"
|
||||
model.eval()
|
||||
return model(**inputs_dict)[0]
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.to(torch_device)
|
||||
output_without_group_offloading = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
|
||||
output_with_group_offloading1 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
|
||||
output_with_group_offloading2 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(torch_device, offload_type="leaf_level")
|
||||
output_with_group_offloading3 = run_forward(model)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.enable_group_offload(
|
||||
torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
|
||||
)
|
||||
output_with_group_offloading4 = run_forward(model)
|
||||
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading1,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output should match with block-level offloading",
|
||||
)
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading2,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output should match with non-blocking block-level offloading",
|
||||
)
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading3,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output should match with leaf-level offloading",
|
||||
)
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading4,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output should match with leaf-level offloading with stream",
|
||||
)
|
||||
|
||||
@require_group_offload_support
|
||||
@pytest.mark.parametrize("record_stream", [False, True])
|
||||
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
|
||||
@torch.no_grad()
|
||||
def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
|
||||
torch.manual_seed(0)
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
_ = model(**inputs_dict)[0]
|
||||
|
||||
torch.manual_seed(0)
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
storage_dtype, compute_dtype = torch.float16, torch.float32
|
||||
inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype)
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
|
||||
model.enable_group_offload(
|
||||
torch_device, offload_type=offload_type, use_stream=True, record_stream=record_stream, **additional_kwargs
|
||||
)
|
||||
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||
_ = model(**inputs_dict)[0]
|
||||
|
||||
@require_group_offload_support
|
||||
@pytest.mark.parametrize("record_stream", [False, True])
|
||||
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
|
||||
@torch.no_grad()
|
||||
@torch.inference_mode()
|
||||
def test_group_offloading_with_disk(self, tmp_path, record_stream, offload_type, atol=1e-5, rtol=0):
|
||||
def _has_generator_arg(model):
|
||||
sig = inspect.signature(model.forward)
|
||||
params = sig.parameters
|
||||
return "generator" in params
|
||||
|
||||
def _run_forward(model, inputs_dict):
|
||||
accepts_generator = _has_generator_arg(model)
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
torch.manual_seed(0)
|
||||
return model(**inputs_dict)[0]
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.eval()
|
||||
model.to(torch_device)
|
||||
output_without_group_offloading = _run_forward(model, inputs_dict)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.eval()
|
||||
|
||||
num_blocks_per_group = None if offload_type == "leaf_level" else 1
|
||||
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
|
||||
tmpdir = str(tmp_path)
|
||||
model.enable_group_offload(
|
||||
torch_device,
|
||||
offload_type=offload_type,
|
||||
offload_to_disk_path=tmpdir,
|
||||
use_stream=True,
|
||||
record_stream=record_stream,
|
||||
**additional_kwargs,
|
||||
)
|
||||
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
|
||||
assert has_safetensors, "No safetensors found in the directory."
|
||||
|
||||
# For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
|
||||
# in nature. So, skip it.
|
||||
if offload_type != "leaf_level":
|
||||
is_correct, extra_files, missing_files = _check_safetensors_serialization(
|
||||
module=model,
|
||||
offload_to_disk_path=tmpdir,
|
||||
offload_type=offload_type,
|
||||
num_blocks_per_group=num_blocks_per_group,
|
||||
)
|
||||
if not is_correct:
|
||||
if extra_files:
|
||||
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
|
||||
elif missing_files:
|
||||
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
|
||||
|
||||
output_with_group_offloading = _run_forward(model, inputs_dict)
|
||||
assert_tensors_close(
|
||||
output_without_group_offloading,
|
||||
output_with_group_offloading,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
msg="Output should match with disk-based group offloading",
|
||||
)
|
||||
|
||||
|
||||
class LayerwiseCastingTesterMixin:
|
||||
"""
|
||||
Mixin class for testing layerwise dtype casting for memory optimization.
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def test_layerwise_casting_memory(self):
|
||||
MB_TOLERANCE = 0.2
|
||||
LEAST_COMPUTE_CAPABILITY = 8.0
|
||||
|
||||
def reset_memory_stats():
|
||||
gc.collect()
|
||||
backend_synchronize(torch_device)
|
||||
backend_empty_cache(torch_device)
|
||||
backend_reset_peak_memory_stats(torch_device)
|
||||
|
||||
def get_memory_usage(storage_dtype, compute_dtype):
|
||||
torch.manual_seed(0)
|
||||
config = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype)
|
||||
model = self.model_class(**config).eval()
|
||||
model = model.to(torch_device, dtype=compute_dtype)
|
||||
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||
|
||||
reset_memory_stats()
|
||||
model(**inputs_dict)
|
||||
model_memory_footprint = model.get_memory_footprint()
|
||||
peak_inference_memory_allocated_mb = backend_max_memory_allocated(torch_device) / 1024**2
|
||||
|
||||
return model_memory_footprint, peak_inference_memory_allocated_mb
|
||||
|
||||
fp32_memory_footprint, fp32_max_memory = get_memory_usage(torch.float32, torch.float32)
|
||||
fp8_e4m3_fp32_memory_footprint, fp8_e4m3_fp32_max_memory = get_memory_usage(torch.float8_e4m3fn, torch.float32)
|
||||
fp8_e4m3_bf16_memory_footprint, fp8_e4m3_bf16_max_memory = get_memory_usage(
|
||||
torch.float8_e4m3fn, torch.bfloat16
|
||||
)
|
||||
|
||||
compute_capability = get_torch_cuda_device_capability() if torch_device == "cuda" else None
|
||||
assert fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint, (
|
||||
"Memory footprint should decrease with lower precision storage"
|
||||
)
|
||||
|
||||
# NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
|
||||
# On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
|
||||
if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
|
||||
assert fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory, (
|
||||
"Peak memory should be lower with bf16 compute on newer GPUs"
|
||||
)
|
||||
|
||||
# On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
|
||||
# bytes. This only happens for some models, so we allow a small tolerance.
|
||||
# For any real model being tested, the order would be fp8_e4m3_bf16 < fp8_e4m3_fp32 < fp32.
|
||||
assert (
|
||||
fp8_e4m3_fp32_max_memory < fp32_max_memory
|
||||
or abs(fp8_e4m3_fp32_max_memory - fp32_max_memory) < MB_TOLERANCE
|
||||
), "Peak memory should be lower or within tolerance with fp8 storage"
|
||||
|
||||
def test_layerwise_casting_training(self):
|
||||
def test_fn(storage_dtype, compute_dtype):
|
||||
if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
|
||||
pytest.skip("Skipping test because CPU doesn't go well with bfloat16.")
|
||||
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model = model.to(torch_device, dtype=compute_dtype)
|
||||
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
|
||||
model.train()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
inputs_dict = cast_inputs_to_dtype(inputs_dict, torch.float32, compute_dtype)
|
||||
with torch.amp.autocast(device_type=torch.device(torch_device).type):
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
input_tensor = inputs_dict[self.main_input_name]
|
||||
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
|
||||
noise = cast_inputs_to_dtype(noise, torch.float32, compute_dtype)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
|
||||
loss.backward()
|
||||
|
||||
test_fn(torch.float16, torch.float32)
|
||||
test_fn(torch.float8_e4m3fn, torch.float32)
|
||||
test_fn(torch.float8_e5m2, torch.float32)
|
||||
test_fn(torch.float8_e4m3fn, torch.bfloat16)
|
||||
|
||||
|
||||
@is_memory
|
||||
@require_accelerator
|
||||
class MemoryTesterMixin(CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin):
|
||||
"""
|
||||
Combined mixin class for all memory optimization tests including CPU/disk offloading,
|
||||
group offloading, and layerwise dtype casting.
|
||||
|
||||
This mixin inherits from:
|
||||
- CPUOffloadTesterMixin: CPU and disk offloading tests
|
||||
- GroupOffloadTesterMixin: Group offloading tests (block-level and leaf-level)
|
||||
- LayerwiseCastingTesterMixin: Layerwise dtype casting tests
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
|
||||
Optional properties:
|
||||
- model_split_percents: List of percentages for splitting model across devices (default: [0.5, 0.7])
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: memory
|
||||
Use `pytest -m "not memory"` to skip these tests
|
||||
"""
|
||||
@@ -1,128 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import socket
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from diffusers.models._modeling_parallel import ContextParallelConfig
|
||||
|
||||
from ...testing_utils import (
|
||||
is_context_parallel,
|
||||
require_torch_multi_accelerator,
|
||||
)
|
||||
|
||||
|
||||
def _find_free_port():
|
||||
"""Find a free port on localhost."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
s.listen(1)
|
||||
port = s.getsockname()[1]
|
||||
return port
|
||||
|
||||
|
||||
def _context_parallel_worker(rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict):
|
||||
"""Worker function for context parallel testing."""
|
||||
try:
|
||||
# Set up distributed environment
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
# Initialize process group
|
||||
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
||||
|
||||
# Set device for this process
|
||||
torch.cuda.set_device(rank)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
|
||||
# Create model
|
||||
model = model_class(**init_dict)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
# Move inputs to device
|
||||
inputs_on_device = {}
|
||||
for key, value in inputs_dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
inputs_on_device[key] = value.to(device)
|
||||
else:
|
||||
inputs_on_device[key] = value
|
||||
|
||||
# Enable context parallelism
|
||||
cp_config = ContextParallelConfig(**cp_dict)
|
||||
model.enable_parallelism(config=cp_config)
|
||||
|
||||
# Run forward pass
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_on_device, return_dict=False)[0]
|
||||
|
||||
# Only rank 0 reports results
|
||||
if rank == 0:
|
||||
return_dict["status"] = "success"
|
||||
return_dict["output_shape"] = list(output.shape)
|
||||
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
return_dict["status"] = "error"
|
||||
return_dict["error"] = str(e)
|
||||
finally:
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@is_context_parallel
|
||||
@require_torch_multi_accelerator
|
||||
class ContextParallelTesterMixin:
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_inference(self, cp_type):
|
||||
if not torch.distributed.is_available():
|
||||
pytest.skip("torch.distributed is not available.")
|
||||
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
# Move all tensors to CPU for multiprocessing
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
cp_dict = {cp_type: world_size}
|
||||
|
||||
# Find a free port for distributed communication
|
||||
master_port = _find_free_port()
|
||||
|
||||
# Use multiprocessing manager for cross-process communication
|
||||
manager = mp.Manager()
|
||||
return_dict = manager.dict()
|
||||
|
||||
# Spawn worker processes
|
||||
mp.spawn(
|
||||
_context_parallel_worker,
|
||||
args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict),
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,272 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
|
||||
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_single_file,
|
||||
nightly,
|
||||
require_torch_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
from .common import check_device_map_is_respected
|
||||
|
||||
|
||||
def download_single_file_checkpoint(pretrained_model_name_or_path, filename, tmpdir):
|
||||
"""Download a single file checkpoint from the Hub to a temporary directory."""
|
||||
path = hf_hub_download(pretrained_model_name_or_path, filename=filename, local_dir=tmpdir)
|
||||
return path
|
||||
|
||||
|
||||
def download_diffusers_config(pretrained_model_name_or_path, tmpdir):
|
||||
"""Download diffusers config files (excluding weights) from a repository."""
|
||||
path = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
ignore_patterns=[
|
||||
"**/*.ckpt",
|
||||
"*.ckpt",
|
||||
"**/*.bin",
|
||||
"*.bin",
|
||||
"**/*.pt",
|
||||
"*.pt",
|
||||
"**/*.safetensors",
|
||||
"*.safetensors",
|
||||
],
|
||||
allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
|
||||
local_dir=tmpdir,
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@is_single_file
|
||||
class SingleFileTesterMixin:
|
||||
"""
|
||||
Mixin class for testing single file loading for models.
|
||||
|
||||
Required properties (must be implemented by subclasses):
|
||||
- ckpt_path: Path or Hub path to the single file checkpoint
|
||||
|
||||
Optional properties:
|
||||
- torch_dtype: torch dtype to use for testing (default: None)
|
||||
- alternate_ckpt_paths: List of alternate checkpoint paths for variant testing (default: None)
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
- pretrained_model_name_or_path: Hub repository ID for the pretrained model
|
||||
- pretrained_model_kwargs: Additional kwargs for from_pretrained (e.g., subfolder)
|
||||
|
||||
Pytest mark: single_file
|
||||
Use `pytest -m "not single_file"` to skip these tests
|
||||
"""
|
||||
|
||||
# ==================== Required Properties ====================
|
||||
|
||||
@property
|
||||
def ckpt_path(self) -> str:
|
||||
"""Path or Hub path to the single file checkpoint. Must be implemented by subclasses."""
|
||||
raise NotImplementedError("Subclasses must implement the `ckpt_path` property.")
|
||||
|
||||
# ==================== Optional Properties ====================
|
||||
|
||||
@property
|
||||
def torch_dtype(self) -> torch.dtype | None:
|
||||
"""torch dtype to use for single file testing."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def alternate_ckpt_paths(self) -> list[str] | None:
|
||||
"""List of alternate checkpoint paths for variant testing."""
|
||||
return None
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_model_config(self):
|
||||
pretrained_kwargs = {"device": torch_device, **self.pretrained_model_kwargs}
|
||||
single_file_kwargs = {"device": torch_device}
|
||||
|
||||
if self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
|
||||
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs between pretrained loading and single file loading: "
|
||||
f"pretrained={model.config[param_name]}, single_file={param_value}"
|
||||
)
|
||||
|
||||
def test_single_file_model_parameters(self):
|
||||
pretrained_kwargs = {"device_map": str(torch_device), **self.pretrained_model_kwargs}
|
||||
single_file_kwargs = {"device": torch_device}
|
||||
|
||||
if self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
# Load pretrained model, get state dict on CPU, then free GPU memory
|
||||
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||
state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
|
||||
del model
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
# Load single file model, get state dict on CPU
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
|
||||
state_dict_single_file = {k: v.cpu() for k, v in model_single_file.state_dict().items()}
|
||||
del model_single_file
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
|
||||
"Model parameters keys differ between pretrained and single file loading. "
|
||||
f"Missing in single file: {set(state_dict.keys()) - set(state_dict_single_file.keys())}. "
|
||||
f"Extra in single file: {set(state_dict_single_file.keys()) - set(state_dict.keys())}"
|
||||
)
|
||||
|
||||
for key in state_dict.keys():
|
||||
param = state_dict[key]
|
||||
param_single_file = state_dict_single_file[key]
|
||||
|
||||
assert param.shape == param_single_file.shape, (
|
||||
f"Parameter shape mismatch for {key}: "
|
||||
f"pretrained {param.shape} vs single file {param_single_file.shape}"
|
||||
)
|
||||
|
||||
assert torch.equal(param, param_single_file), f"Parameter values differ for {key}"
|
||||
|
||||
def test_single_file_loading_local_files_only(self, tmp_path):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
|
||||
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
|
||||
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
local_ckpt_path, local_files_only=True, **single_file_kwargs
|
||||
)
|
||||
|
||||
assert model_single_file is not None, "Failed to load model with local_files_only=True"
|
||||
|
||||
def test_single_file_loading_with_diffusers_config(self):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
single_file_kwargs.update(self.pretrained_model_kwargs)
|
||||
|
||||
# Load with config parameter
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
self.ckpt_path, config=self.pretrained_model_name_or_path, **single_file_kwargs
|
||||
)
|
||||
|
||||
# Load pretrained for comparison
|
||||
pretrained_kwargs = {**self.pretrained_model_kwargs}
|
||||
if self.torch_dtype:
|
||||
pretrained_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_pretrained(self.pretrained_model_name_or_path, **pretrained_kwargs)
|
||||
|
||||
# Compare configs
|
||||
PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
|
||||
for param_name, param_value in model_single_file.config.items():
|
||||
if param_name in PARAMS_TO_IGNORE:
|
||||
continue
|
||||
assert model.config[param_name] == param_value, (
|
||||
f"{param_name} differs: pretrained={model.config[param_name]}, single_file={param_value}"
|
||||
)
|
||||
|
||||
def test_single_file_loading_with_diffusers_config_local_files_only(self, tmp_path):
|
||||
single_file_kwargs = {}
|
||||
|
||||
if self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
single_file_kwargs.update(self.pretrained_model_kwargs)
|
||||
|
||||
pretrained_model_name_or_path, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
|
||||
local_ckpt_path = download_single_file_checkpoint(pretrained_model_name_or_path, weight_name, str(tmp_path))
|
||||
local_diffusers_config = download_diffusers_config(self.pretrained_model_name_or_path, str(tmp_path))
|
||||
|
||||
model_single_file = self.model_class.from_single_file(
|
||||
local_ckpt_path, config=local_diffusers_config, local_files_only=True, **single_file_kwargs
|
||||
)
|
||||
|
||||
assert model_single_file is not None, "Failed to load model with config and local_files_only=True"
|
||||
|
||||
def test_single_file_loading_dtype(self):
|
||||
for dtype in [torch.float32, torch.float16]:
|
||||
if torch_device == "mps" and dtype == torch.bfloat16:
|
||||
continue
|
||||
|
||||
model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=dtype)
|
||||
|
||||
assert model_single_file.dtype == dtype, f"Expected dtype {dtype}, got {model_single_file.dtype}"
|
||||
|
||||
# Cleanup
|
||||
del model_single_file
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_checkpoint_variant_loading(self):
|
||||
if not self.alternate_ckpt_paths:
|
||||
return
|
||||
|
||||
for ckpt_path in self.alternate_ckpt_paths:
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
single_file_kwargs = {}
|
||||
if self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
|
||||
|
||||
assert model is not None, f"Failed to load checkpoint from {ckpt_path}"
|
||||
|
||||
del model
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_single_file_loading_with_device_map(self):
|
||||
single_file_kwargs = {"device_map": torch_device}
|
||||
|
||||
if self.torch_dtype:
|
||||
single_file_kwargs["torch_dtype"] = self.torch_dtype
|
||||
|
||||
model = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
|
||||
|
||||
assert model is not None, "Failed to load model with device_map"
|
||||
assert hasattr(model, "hf_device_map"), "Model should have hf_device_map attribute when loaded with device_map"
|
||||
assert model.hf_device_map is not None, "hf_device_map should not be None when loaded with device_map"
|
||||
check_device_map_is_respected(model, model.hf_device_map)
|
||||
@@ -1,220 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.training_utils import EMAModel
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
is_training,
|
||||
require_torch_accelerator_with_training,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@is_training
|
||||
@require_torch_accelerator_with_training
|
||||
class TrainingTesterMixin:
|
||||
"""
|
||||
Mixin class for testing training functionality on models.
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
- output_shape: Tuple defining the expected output shape
|
||||
|
||||
Expected methods from config mixin:
|
||||
- get_init_dict(): Returns dict of arguments to initialize the model
|
||||
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
|
||||
|
||||
Pytest mark: training
|
||||
Use `pytest -m "not training"` to skip these tests
|
||||
"""
|
||||
|
||||
def setup_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_training(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
loss.backward()
|
||||
|
||||
def test_training_with_ema(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
ema_model = EMAModel(model.parameters())
|
||||
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
loss.backward()
|
||||
ema_model.step(model.parameters())
|
||||
|
||||
def test_gradient_checkpointing(self):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
pytest.skip("Gradient checkpointing is not supported.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
|
||||
# at init model should have gradient checkpointing disabled
|
||||
model = self.model_class(**init_dict)
|
||||
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled at init"
|
||||
|
||||
# check enable works
|
||||
model.enable_gradient_checkpointing()
|
||||
assert model.is_gradient_checkpointing, "Gradient checkpointing should be enabled"
|
||||
|
||||
# check disable works
|
||||
model.disable_gradient_checkpointing()
|
||||
assert not model.is_gradient_checkpointing, "Gradient checkpointing should be disabled"
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self, expected_set=None):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
pytest.skip("Gradient checkpointing is not supported.")
|
||||
|
||||
if expected_set is None:
|
||||
pytest.skip("expected_set must be provided to verify gradient checkpointing is applied.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
|
||||
model_class_copy = copy.copy(self.model_class)
|
||||
model = model_class_copy(**init_dict)
|
||||
model.enable_gradient_checkpointing()
|
||||
|
||||
modules_with_gc_enabled = {}
|
||||
for submodule in model.modules():
|
||||
if hasattr(submodule, "gradient_checkpointing"):
|
||||
assert submodule.gradient_checkpointing, f"{submodule.__class__.__name__} should have GC enabled"
|
||||
modules_with_gc_enabled[submodule.__class__.__name__] = True
|
||||
|
||||
assert set(modules_with_gc_enabled.keys()) == expected_set, (
|
||||
f"Modules with GC enabled {set(modules_with_gc_enabled.keys())} do not match expected set {expected_set}"
|
||||
)
|
||||
assert all(modules_with_gc_enabled.values()), "All modules should have GC enabled"
|
||||
|
||||
def test_gradient_checkpointing_equivalence(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip=None):
|
||||
if not self.model_class._supports_gradient_checkpointing:
|
||||
pytest.skip("Gradient checkpointing is not supported.")
|
||||
|
||||
if skip is None:
|
||||
skip = set()
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
inputs_dict_copy = copy.deepcopy(inputs_dict)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
assert not model.is_gradient_checkpointing and model.training
|
||||
|
||||
out = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
# run the backwards pass on the model
|
||||
model.zero_grad()
|
||||
|
||||
labels = torch.randn_like(out)
|
||||
loss = (out - labels).mean()
|
||||
loss.backward()
|
||||
|
||||
# re-instantiate the model now enabling gradient checkpointing
|
||||
torch.manual_seed(0)
|
||||
model_2 = self.model_class(**init_dict)
|
||||
# clone model
|
||||
model_2.load_state_dict(model.state_dict())
|
||||
model_2.to(torch_device)
|
||||
model_2.enable_gradient_checkpointing()
|
||||
|
||||
assert model_2.is_gradient_checkpointing and model_2.training
|
||||
|
||||
out_2 = model_2(**inputs_dict_copy, return_dict=False)[0]
|
||||
|
||||
# run the backwards pass on the model
|
||||
model_2.zero_grad()
|
||||
loss_2 = (out_2 - labels).mean()
|
||||
loss_2.backward()
|
||||
|
||||
# compare the output and parameters gradients
|
||||
assert (loss - loss_2).abs() < loss_tolerance, (
|
||||
f"Loss difference {(loss - loss_2).abs()} exceeds tolerance {loss_tolerance}"
|
||||
)
|
||||
|
||||
named_params = dict(model.named_parameters())
|
||||
named_params_2 = dict(model_2.named_parameters())
|
||||
|
||||
for name, param in named_params.items():
|
||||
if "post_quant_conv" in name:
|
||||
continue
|
||||
if name in skip:
|
||||
continue
|
||||
if param.grad is None:
|
||||
continue
|
||||
|
||||
assert torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol), (
|
||||
f"Gradient mismatch for {name}"
|
||||
)
|
||||
|
||||
def test_mixed_precision_training(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.train()
|
||||
|
||||
# Test with float16
|
||||
if torch.device(torch_device).type != "cpu":
|
||||
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.float16):
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
|
||||
loss.backward()
|
||||
|
||||
# Test with bfloat16
|
||||
if torch.device(torch_device).type != "cpu":
|
||||
model.zero_grad()
|
||||
with torch.amp.autocast(device_type=torch.device(torch_device).type, dtype=torch.bfloat16):
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
noise = torch.randn((output.shape[0],) + self.output_shape).to(torch_device)
|
||||
loss = torch.nn.functional.mse_loss(output, noise)
|
||||
|
||||
loss.backward()
|
||||
@@ -13,52 +13,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import FluxTransformer2DModel
|
||||
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
|
||||
from diffusers.models.embeddings import ImageProjection
|
||||
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesCompileTesterMixin,
|
||||
BitsAndBytesTesterMixin,
|
||||
ContextParallelTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
IPAdapterTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelOptCompileTesterMixin,
|
||||
ModelOptTesterMixin,
|
||||
ModelTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
QuantoCompileTesterMixin,
|
||||
QuantoTesterMixin,
|
||||
SingleFileTesterMixin,
|
||||
TorchAoCompileTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ...testing_utils import enable_full_determinism, is_peft_available, torch_device
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# TODO: This standalone function maintains backward compatibility with pipeline tests
|
||||
# (tests/pipelines/test_pipelines_common.py) and will be refactored.
|
||||
def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
"""Create a dummy IP Adapter state dict for Flux transformer testing."""
|
||||
def create_flux_ip_adapter_state_dict(model):
|
||||
# "ip_adapter" (cross-attention weights)
|
||||
ip_cross_attn_state_dict = {}
|
||||
key_id = 0
|
||||
|
||||
@@ -68,7 +39,7 @@ def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
|
||||
joint_attention_dim = model.config["joint_attention_dim"]
|
||||
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
|
||||
sd = FluxIPAdapterAttnProcessor(
|
||||
sd = FluxIPAdapterJointAttnProcessor2_0(
|
||||
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
|
||||
).state_dict()
|
||||
ip_cross_attn_state_dict.update(
|
||||
@@ -79,8 +50,11 @@ def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
key_id += 1
|
||||
|
||||
# "image_proj" (ImageProjection layer weights)
|
||||
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=model.config["joint_attention_dim"],
|
||||
image_embed_dim=(
|
||||
@@ -101,45 +75,57 @@ def create_flux_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
)
|
||||
|
||||
del sd
|
||||
return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}
|
||||
ip_state_dict = {}
|
||||
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
class FluxTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return FluxTransformer2DModel
|
||||
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "hf-internal-testing/tiny-flux-pipe"
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def pretrained_model_kwargs(self):
|
||||
return {"subfolder": "transformer"}
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
def input_shape(self):
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
def output_shape(self):
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.9]
|
||||
def prepare_dummy_input(self, height=4, width=4):
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device)
|
||||
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
|
||||
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||
"""Return Flux model initialization arguments."""
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"pooled_projections": pooled_prompt_embeds,
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
@@ -151,40 +137,11 @@ class FluxTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"axes_dims_rope": [4, 4, 8],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
height = width = 4
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"pooled_projections": randn_tensor(
|
||||
(batch_size, embedding_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"img_ids": randn_tensor(
|
||||
(height * width, num_image_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"txt_ids": randn_tensor(
|
||||
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_deprecated_inputs_img_txt_ids_3d(self):
|
||||
"""Test that deprecated 3D img_ids and txt_ids still work."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -205,228 +162,63 @@ class TestFluxTransformer(FluxTransformerTesterConfig, ModelTesterMixin):
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict).to_tuple()[0]
|
||||
|
||||
assert output_1.shape == output_2.shape
|
||||
assert torch.allclose(output_1, output_2, atol=1e-5), (
|
||||
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
|
||||
"are not equal as them as 2d inputs"
|
||||
self.assertEqual(output_1.shape, output_2.shape)
|
||||
self.assertTrue(
|
||||
torch.allclose(output_1, output_2, atol=1e-5),
|
||||
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
|
||||
)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"FluxTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Flux Transformer."""
|
||||
# The test exists for cases like
|
||||
# https://github.com/huggingface/diffusers/issues/11874
|
||||
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
|
||||
def test_lora_exclude_modules(self):
|
||||
from peft import LoraConfig, get_peft_model_state_dict, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
lora_rank = 4
|
||||
target_module = "single_transformer_blocks.0.proj_out"
|
||||
adapter_name = "foo"
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
|
||||
"""Context Parallel inference tests for Flux Transformer"""
|
||||
|
||||
|
||||
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
|
||||
"""IP Adapter tests for Flux Transformer."""
|
||||
|
||||
@property
|
||||
def ip_adapter_processor_cls(self):
|
||||
return FluxIPAdapterAttnProcessor
|
||||
|
||||
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
|
||||
torch.manual_seed(0)
|
||||
# Create dummy image embeds for IP adapter
|
||||
cross_attention_dim = getattr(model.config, "joint_attention_dim", 32)
|
||||
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
|
||||
|
||||
inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}})
|
||||
|
||||
return inputs_dict
|
||||
|
||||
def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
|
||||
return create_flux_ip_adapter_state_dict(model)
|
||||
|
||||
|
||||
class TestFluxTransformerLoRA(FluxTransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerLoRAHotSwap(FluxTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for Flux Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for LoRA hotswap tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 24
|
||||
embedding_dim = 32
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
|
||||
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
|
||||
"pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device),
|
||||
"img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
|
||||
"txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
state_dict = model.state_dict()
|
||||
target_mod_shape = state_dict[f"{target_module}.weight"].shape
|
||||
lora_state_dict = {
|
||||
f"{target_module}.lora_A.weight": torch.ones(lora_rank, target_mod_shape[1]) * 22,
|
||||
f"{target_module}.lora_B.weight": torch.ones(target_mod_shape[0], lora_rank) * 33,
|
||||
}
|
||||
# Passing exclude_modules should no longer be necessary (or even passing target_modules, for that matter).
|
||||
config = LoraConfig(
|
||||
r=lora_rank, target_modules=["single_transformer_blocks.0.proj_out"], exclude_modules=["proj_out"]
|
||||
)
|
||||
inject_adapter_in_model(config, model, adapter_name=adapter_name, state_dict=lora_state_dict)
|
||||
set_peft_model_state_dict(model, lora_state_dict, adapter_name)
|
||||
retrieved_lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
|
||||
assert len(retrieved_lora_state_dict) == len(lora_state_dict)
|
||||
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_A.weight"] == 22).all()
|
||||
assert (retrieved_lora_state_dict["single_transformer_blocks.0.proj_out.lora_B.weight"] == 33).all()
|
||||
|
||||
|
||||
class TestFluxTransformerCompile(FluxTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for compilation tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 24
|
||||
embedding_dim = 32
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
|
||||
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
|
||||
"pooled_projections": randn_tensor((batch_size, embedding_dim), device=torch_device),
|
||||
"img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
|
||||
"txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
|
||||
|
||||
class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
|
||||
@property
|
||||
def ckpt_path(self):
|
||||
return "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
|
||||
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
||||
model_class = FluxTransformer2DModel
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
@property
|
||||
def alternate_ckpt_paths(self):
|
||||
return ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "black-forest-labs/FLUX.1-dev"
|
||||
|
||||
|
||||
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
|
||||
"""Quanto quantization tests for Flux Transformer."""
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "hf-internal-testing/tiny-flux-transformer"
|
||||
|
||||
@property
|
||||
def pretrained_model_kwargs(self):
|
||||
return {}
|
||||
|
||||
|
||||
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real FLUX model dimensions."""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 4096, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"pooled_projections": randn_tensor(
|
||||
(1, 768), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
|
||||
"img_ids": randn_tensor((4096, 3), generator=self.generator, device=torch_device, dtype=self.torch_dtype),
|
||||
"txt_ids": randn_tensor((512, 3), generator=self.generator, device=torch_device, dtype=self.torch_dtype),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerQuantoCompile(FluxTransformerTesterConfig, QuantoCompileTesterMixin):
|
||||
"""Quanto + compile tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerTorchAoCompile(FluxTransformerTesterConfig, TorchAoCompileTesterMixin):
|
||||
"""TorchAO + compile tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTesterMixin):
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real FLUX model dimensions."""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 4096, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"pooled_projections": randn_tensor(
|
||||
(1, 768), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
|
||||
"img_ids": randn_tensor((4096, 3), generator=self.generator, device=torch_device, dtype=self.torch_dtype),
|
||||
"txt_ids": randn_tensor((512, 3), generator=self.generator, device=torch_device, dtype=self.torch_dtype),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
|
||||
"""ModelOpt quantization tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
|
||||
"""ModelOpt + compile tests for Flux Transformer."""
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="torch.compile is not supported by BitsAndBytes")
|
||||
class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAndBytesCompileTesterMixin):
|
||||
"""BitsAndBytes + compile tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
|
||||
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
|
||||
"""FirstBlockCache tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTesterMixin):
|
||||
"""FasterCache tests for Flux Transformer."""
|
||||
|
||||
# Flux is guidance distilled, so we can test at model level without CFG batch handling
|
||||
FASTER_CACHE_CONFIG = {
|
||||
"spatial_attention_block_skip_range": 2,
|
||||
"spatial_attention_timestep_skip_range": (-1, 901),
|
||||
"tensor_format": "BCHW",
|
||||
"is_guidance_distilled": True,
|
||||
}
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
|
||||
@@ -38,7 +38,6 @@ from diffusers.utils.import_utils import (
|
||||
is_gguf_available,
|
||||
is_kernels_available,
|
||||
is_note_seq_available,
|
||||
is_nvidia_modelopt_version,
|
||||
is_onnx_available,
|
||||
is_opencv_available,
|
||||
is_optimum_quanto_available,
|
||||
@@ -131,59 +130,6 @@ def torch_all_close(a, b, *args, **kwargs):
|
||||
return True
|
||||
|
||||
|
||||
def assert_tensors_close(
|
||||
actual: "torch.Tensor",
|
||||
expected: "torch.Tensor",
|
||||
atol: float = 1e-5,
|
||||
rtol: float = 1e-5,
|
||||
msg: str = "",
|
||||
) -> None:
|
||||
"""
|
||||
Assert that two tensors are close within tolerance.
|
||||
|
||||
Uses the same formula as torch.allclose: |actual - expected| <= atol + rtol * |expected|
|
||||
Provides concise, actionable error messages without dumping full tensors.
|
||||
|
||||
Args:
|
||||
actual: The actual tensor from the computation.
|
||||
expected: The expected tensor to compare against.
|
||||
atol: Absolute tolerance.
|
||||
rtol: Relative tolerance.
|
||||
msg: Optional message prefix for the assertion error.
|
||||
|
||||
Raises:
|
||||
AssertionError: If tensors have different shapes or values exceed tolerance.
|
||||
|
||||
Example:
|
||||
>>> assert_tensors_close(output, expected_output, atol=1e-5, rtol=1e-5, msg="Forward pass")
|
||||
"""
|
||||
if not is_torch_available():
|
||||
raise ValueError("PyTorch needs to be installed to use this function.")
|
||||
|
||||
if actual.shape != expected.shape:
|
||||
raise AssertionError(f"{msg} Shape mismatch: actual {actual.shape} vs expected {expected.shape}")
|
||||
|
||||
if not torch.allclose(actual, expected, atol=atol, rtol=rtol):
|
||||
abs_diff = (actual - expected).abs()
|
||||
max_diff = abs_diff.max().item()
|
||||
|
||||
flat_idx = abs_diff.argmax().item()
|
||||
max_idx = tuple(torch.unravel_index(torch.tensor(flat_idx), actual.shape).tolist())
|
||||
|
||||
threshold = atol + rtol * expected.abs()
|
||||
mismatched = (abs_diff > threshold).sum().item()
|
||||
total = actual.numel()
|
||||
|
||||
raise AssertionError(
|
||||
f"{msg}\n"
|
||||
f"Tensors not close! Mismatched elements: {mismatched}/{total} ({100 * mismatched / total:.1f}%)\n"
|
||||
f" Max diff: {max_diff:.6e} at index {max_idx}\n"
|
||||
f" Actual: {actual.flatten()[flat_idx].item():.6e}\n"
|
||||
f" Expected: {expected.flatten()[flat_idx].item():.6e}\n"
|
||||
f" atol: {atol:.6e}, rtol: {rtol:.6e}"
|
||||
)
|
||||
|
||||
|
||||
def numpy_cosine_similarity_distance(a, b):
|
||||
similarity = np.dot(a, b) / (norm(a) * norm(b))
|
||||
distance = 1.0 - similarity.mean()
|
||||
@@ -295,6 +241,7 @@ def parse_flag_from_env(key, default=False):
|
||||
|
||||
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
|
||||
_run_nightly_tests = parse_flag_from_env("RUN_NIGHTLY", default=False)
|
||||
_run_compile_tests = parse_flag_from_env("RUN_COMPILE", default=False)
|
||||
|
||||
|
||||
def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
||||
@@ -335,155 +282,12 @@ def nightly(test_case):
|
||||
|
||||
def is_torch_compile(test_case):
|
||||
"""
|
||||
Decorator marking a test as a torch.compile test. These tests can be filtered using:
|
||||
pytest -m "not compile" to skip
|
||||
pytest -m compile to run only these tests
|
||||
"""
|
||||
return pytest.mark.compile(test_case)
|
||||
Decorator marking a test that runs compile tests in the diffusers CI.
|
||||
|
||||
Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
|
||||
|
||||
def is_single_file(test_case):
|
||||
"""
|
||||
Decorator marking a test as a single file loading test. These tests can be filtered using:
|
||||
pytest -m "not single_file" to skip
|
||||
pytest -m single_file to run only these tests
|
||||
"""
|
||||
return pytest.mark.single_file(test_case)
|
||||
|
||||
|
||||
def is_lora(test_case):
|
||||
"""
|
||||
Decorator marking a test as a LoRA test. These tests can be filtered using:
|
||||
pytest -m "not lora" to skip
|
||||
pytest -m lora to run only these tests
|
||||
"""
|
||||
return pytest.mark.lora(test_case)
|
||||
|
||||
|
||||
def is_ip_adapter(test_case):
|
||||
"""
|
||||
Decorator marking a test as an IP Adapter test. These tests can be filtered using:
|
||||
pytest -m "not ip_adapter" to skip
|
||||
pytest -m ip_adapter to run only these tests
|
||||
"""
|
||||
return pytest.mark.ip_adapter(test_case)
|
||||
|
||||
|
||||
def is_training(test_case):
|
||||
"""
|
||||
Decorator marking a test as a training test. These tests can be filtered using:
|
||||
pytest -m "not training" to skip
|
||||
pytest -m training to run only these tests
|
||||
"""
|
||||
return pytest.mark.training(test_case)
|
||||
|
||||
|
||||
def is_attention(test_case):
|
||||
"""
|
||||
Decorator marking a test as an attention test. These tests can be filtered using:
|
||||
pytest -m "not attention" to skip
|
||||
pytest -m attention to run only these tests
|
||||
"""
|
||||
return pytest.mark.attention(test_case)
|
||||
|
||||
|
||||
def is_memory(test_case):
|
||||
"""
|
||||
Decorator marking a test as a memory optimization test. These tests can be filtered using:
|
||||
pytest -m "not memory" to skip
|
||||
pytest -m memory to run only these tests
|
||||
"""
|
||||
return pytest.mark.memory(test_case)
|
||||
|
||||
|
||||
def is_cpu_offload(test_case):
|
||||
"""
|
||||
Decorator marking a test as a CPU offload test. These tests can be filtered using:
|
||||
pytest -m "not cpu_offload" to skip
|
||||
pytest -m cpu_offload to run only these tests
|
||||
"""
|
||||
return pytest.mark.cpu_offload(test_case)
|
||||
|
||||
|
||||
def is_group_offload(test_case):
|
||||
"""
|
||||
Decorator marking a test as a group offload test. These tests can be filtered using:
|
||||
pytest -m "not group_offload" to skip
|
||||
pytest -m group_offload to run only these tests
|
||||
"""
|
||||
return pytest.mark.group_offload(test_case)
|
||||
|
||||
|
||||
def is_quantization(test_case):
|
||||
"""
|
||||
Decorator marking a test as a quantization test. These tests can be filtered using:
|
||||
pytest -m "not quantization" to skip
|
||||
pytest -m quantization to run only these tests
|
||||
"""
|
||||
return pytest.mark.quantization(test_case)
|
||||
|
||||
|
||||
def is_bitsandbytes(test_case):
|
||||
"""
|
||||
Decorator marking a test as a BitsAndBytes quantization test. These tests can be filtered using:
|
||||
pytest -m "not bitsandbytes" to skip
|
||||
pytest -m bitsandbytes to run only these tests
|
||||
"""
|
||||
return pytest.mark.bitsandbytes(test_case)
|
||||
|
||||
|
||||
def is_quanto(test_case):
|
||||
"""
|
||||
Decorator marking a test as a Quanto quantization test. These tests can be filtered using:
|
||||
pytest -m "not quanto" to skip
|
||||
pytest -m quanto to run only these tests
|
||||
"""
|
||||
return pytest.mark.quanto(test_case)
|
||||
|
||||
|
||||
def is_torchao(test_case):
|
||||
"""
|
||||
Decorator marking a test as a TorchAO quantization test. These tests can be filtered using:
|
||||
pytest -m "not torchao" to skip
|
||||
pytest -m torchao to run only these tests
|
||||
"""
|
||||
return pytest.mark.torchao(test_case)
|
||||
|
||||
|
||||
def is_gguf(test_case):
|
||||
"""
|
||||
Decorator marking a test as a GGUF quantization test. These tests can be filtered using:
|
||||
pytest -m "not gguf" to skip
|
||||
pytest -m gguf to run only these tests
|
||||
"""
|
||||
return pytest.mark.gguf(test_case)
|
||||
|
||||
|
||||
def is_modelopt(test_case):
|
||||
"""
|
||||
Decorator marking a test as a NVIDIA ModelOpt quantization test. These tests can be filtered using:
|
||||
pytest -m "not modelopt" to skip
|
||||
pytest -m modelopt to run only these tests
|
||||
"""
|
||||
return pytest.mark.modelopt(test_case)
|
||||
|
||||
|
||||
def is_context_parallel(test_case):
|
||||
"""
|
||||
Decorator marking a test as a context parallel inference test. These tests can be filtered using:
|
||||
pytest -m "not context_parallel" to skip
|
||||
pytest -m context_parallel to run only these tests
|
||||
"""
|
||||
return pytest.mark.context_parallel(test_case)
|
||||
|
||||
|
||||
def is_cache(test_case):
|
||||
"""
|
||||
Decorator marking a test as a cache test. These tests can be filtered using:
|
||||
pytest -m "not cache" to skip
|
||||
pytest -m cache to run only these tests
|
||||
"""
|
||||
return pytest.mark.cache(test_case)
|
||||
return pytest.mark.skipif(not _run_compile_tests, reason="test is torch compile")(test_case)
|
||||
|
||||
|
||||
def require_torch(test_case):
|
||||
@@ -846,16 +650,6 @@ def require_kernels_version_greater_or_equal(kernels_version):
|
||||
return decorator
|
||||
|
||||
|
||||
def require_modelopt_version_greater_or_equal(modelopt_version):
|
||||
def decorator(test_case):
|
||||
return pytest.mark.skipif(
|
||||
not is_nvidia_modelopt_version(">=", modelopt_version),
|
||||
reason=f"Test requires modelopt with version greater than {modelopt_version}.",
|
||||
)(test_case)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def deprecate_after_peft_backend(test_case):
|
||||
"""
|
||||
Decorator marking a test that will be skipped after PEFT backend
|
||||
|
||||
@@ -1,592 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Utility script to generate test suites for diffusers model classes.
|
||||
|
||||
Usage:
|
||||
python utils/generate_model_tests.py src/diffusers/models/transformers/transformer_flux.py
|
||||
|
||||
This will analyze the model file and generate a test file with appropriate
|
||||
test classes based on the model's mixins and attributes.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
MIXIN_TO_TESTER = {
|
||||
"ModelMixin": "ModelTesterMixin",
|
||||
"PeftAdapterMixin": "LoraTesterMixin",
|
||||
}
|
||||
|
||||
ATTRIBUTE_TO_TESTER = {
|
||||
"_cp_plan": "ContextParallelTesterMixin",
|
||||
"_supports_gradient_checkpointing": "TrainingTesterMixin",
|
||||
}
|
||||
|
||||
ALWAYS_INCLUDE_TESTERS = [
|
||||
"ModelTesterMixin",
|
||||
"MemoryTesterMixin",
|
||||
"TorchCompileTesterMixin",
|
||||
]
|
||||
|
||||
# Attention-related class names that indicate the model uses attention
|
||||
ATTENTION_INDICATORS = {
|
||||
"AttentionMixin",
|
||||
"AttentionModuleMixin",
|
||||
}
|
||||
|
||||
OPTIONAL_TESTERS = [
|
||||
# Quantization testers
|
||||
("BitsAndBytesTesterMixin", "bnb"),
|
||||
("QuantoTesterMixin", "quanto"),
|
||||
("TorchAoTesterMixin", "torchao"),
|
||||
("GGUFTesterMixin", "gguf"),
|
||||
("ModelOptTesterMixin", "modelopt"),
|
||||
# Quantization compile testers
|
||||
("BitsAndBytesCompileTesterMixin", "bnb_compile"),
|
||||
("QuantoCompileTesterMixin", "quanto_compile"),
|
||||
("TorchAoCompileTesterMixin", "torchao_compile"),
|
||||
("GGUFCompileTesterMixin", "gguf_compile"),
|
||||
("ModelOptCompileTesterMixin", "modelopt_compile"),
|
||||
# Cache testers
|
||||
("PyramidAttentionBroadcastTesterMixin", "pab_cache"),
|
||||
("FirstBlockCacheTesterMixin", "fbc_cache"),
|
||||
("FasterCacheTesterMixin", "faster_cache"),
|
||||
# Other testers
|
||||
("SingleFileTesterMixin", "single_file"),
|
||||
("IPAdapterTesterMixin", "ip_adapter"),
|
||||
]
|
||||
|
||||
|
||||
class ModelAnalyzer(ast.NodeVisitor):
|
||||
def __init__(self):
|
||||
self.model_classes = []
|
||||
self.current_class = None
|
||||
self.imports = set()
|
||||
|
||||
def visit_Import(self, node: ast.Import):
|
||||
for alias in node.names:
|
||||
self.imports.add(alias.name.split(".")[-1])
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ImportFrom(self, node: ast.ImportFrom):
|
||||
for alias in node.names:
|
||||
self.imports.add(alias.name)
|
||||
self.generic_visit(node)
|
||||
|
||||
def visit_ClassDef(self, node: ast.ClassDef):
|
||||
base_names = []
|
||||
for base in node.bases:
|
||||
if isinstance(base, ast.Name):
|
||||
base_names.append(base.id)
|
||||
elif isinstance(base, ast.Attribute):
|
||||
base_names.append(base.attr)
|
||||
|
||||
if "ModelMixin" in base_names:
|
||||
class_info = {
|
||||
"name": node.name,
|
||||
"bases": base_names,
|
||||
"attributes": {},
|
||||
"has_forward": False,
|
||||
"init_params": [],
|
||||
}
|
||||
|
||||
for item in node.body:
|
||||
if isinstance(item, ast.Assign):
|
||||
for target in item.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
attr_name = target.id
|
||||
if attr_name.startswith("_"):
|
||||
class_info["attributes"][attr_name] = self._get_value(item.value)
|
||||
|
||||
elif isinstance(item, ast.FunctionDef):
|
||||
if item.name == "forward":
|
||||
class_info["has_forward"] = True
|
||||
class_info["forward_params"] = self._extract_func_params(item)
|
||||
elif item.name == "__init__":
|
||||
class_info["init_params"] = self._extract_func_params(item)
|
||||
|
||||
self.model_classes.append(class_info)
|
||||
|
||||
self.generic_visit(node)
|
||||
|
||||
def _extract_func_params(self, func_node: ast.FunctionDef) -> list[dict]:
|
||||
params = []
|
||||
args = func_node.args
|
||||
|
||||
num_defaults = len(args.defaults)
|
||||
num_args = len(args.args)
|
||||
first_default_idx = num_args - num_defaults
|
||||
|
||||
for i, arg in enumerate(args.args):
|
||||
if arg.arg == "self":
|
||||
continue
|
||||
|
||||
param_info = {"name": arg.arg, "type": None, "default": None}
|
||||
|
||||
if arg.annotation:
|
||||
param_info["type"] = self._get_annotation_str(arg.annotation)
|
||||
|
||||
default_idx = i - first_default_idx
|
||||
if default_idx >= 0 and default_idx < len(args.defaults):
|
||||
param_info["default"] = self._get_value(args.defaults[default_idx])
|
||||
|
||||
params.append(param_info)
|
||||
|
||||
return params
|
||||
|
||||
def _get_annotation_str(self, node) -> str:
|
||||
if isinstance(node, ast.Name):
|
||||
return node.id
|
||||
elif isinstance(node, ast.Constant):
|
||||
return repr(node.value)
|
||||
elif isinstance(node, ast.Subscript):
|
||||
base = self._get_annotation_str(node.value)
|
||||
if isinstance(node.slice, ast.Tuple):
|
||||
args = ", ".join(self._get_annotation_str(el) for el in node.slice.elts)
|
||||
else:
|
||||
args = self._get_annotation_str(node.slice)
|
||||
return f"{base}[{args}]"
|
||||
elif isinstance(node, ast.Attribute):
|
||||
return f"{self._get_annotation_str(node.value)}.{node.attr}"
|
||||
elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
|
||||
left = self._get_annotation_str(node.left)
|
||||
right = self._get_annotation_str(node.right)
|
||||
return f"{left} | {right}"
|
||||
elif isinstance(node, ast.Tuple):
|
||||
return ", ".join(self._get_annotation_str(el) for el in node.elts)
|
||||
return "Any"
|
||||
|
||||
def _get_value(self, node):
|
||||
if isinstance(node, ast.Constant):
|
||||
return node.value
|
||||
elif isinstance(node, ast.Name):
|
||||
if node.id == "None":
|
||||
return None
|
||||
elif node.id == "True":
|
||||
return True
|
||||
elif node.id == "False":
|
||||
return False
|
||||
return node.id
|
||||
elif isinstance(node, ast.List):
|
||||
return [self._get_value(el) for el in node.elts]
|
||||
elif isinstance(node, ast.Dict):
|
||||
return {self._get_value(k): self._get_value(v) for k, v in zip(node.keys, node.values)}
|
||||
return "<complex>"
|
||||
|
||||
|
||||
def analyze_model_file(filepath: str) -> tuple[list[dict], set[str]]:
|
||||
with open(filepath) as f:
|
||||
source = f.read()
|
||||
|
||||
tree = ast.parse(source)
|
||||
analyzer = ModelAnalyzer()
|
||||
analyzer.visit(tree)
|
||||
|
||||
return analyzer.model_classes, analyzer.imports
|
||||
|
||||
|
||||
def determine_testers(model_info: dict, include_optional: list[str], imports: set[str]) -> list[str]:
|
||||
testers = list(ALWAYS_INCLUDE_TESTERS)
|
||||
|
||||
for base in model_info["bases"]:
|
||||
if base in MIXIN_TO_TESTER:
|
||||
tester = MIXIN_TO_TESTER[base]
|
||||
if tester not in testers:
|
||||
testers.append(tester)
|
||||
|
||||
for attr, tester in ATTRIBUTE_TO_TESTER.items():
|
||||
if attr in model_info["attributes"]:
|
||||
value = model_info["attributes"][attr]
|
||||
if value is not None and value is not False:
|
||||
if tester not in testers:
|
||||
testers.append(tester)
|
||||
|
||||
if "_cp_plan" in model_info["attributes"] and model_info["attributes"]["_cp_plan"] is not None:
|
||||
if "ContextParallelTesterMixin" not in testers:
|
||||
testers.append("ContextParallelTesterMixin")
|
||||
|
||||
# Include AttentionTesterMixin if the model imports attention-related classes
|
||||
if imports & ATTENTION_INDICATORS:
|
||||
testers.append("AttentionTesterMixin")
|
||||
|
||||
for tester, flag in OPTIONAL_TESTERS:
|
||||
if flag in include_optional:
|
||||
if tester not in testers:
|
||||
testers.append(tester)
|
||||
|
||||
return testers
|
||||
|
||||
|
||||
def generate_config_class(model_info: dict, model_name: str) -> str:
|
||||
class_name = f"{model_name}TesterConfig"
|
||||
model_class = model_info["name"]
|
||||
forward_params = model_info.get("forward_params", [])
|
||||
init_params = model_info.get("init_params", [])
|
||||
|
||||
lines = [
|
||||
f"class {class_name}:",
|
||||
" @property",
|
||||
" def model_class(self):",
|
||||
f" return {model_class}",
|
||||
"",
|
||||
" @property",
|
||||
" def pretrained_model_name_or_path(self):",
|
||||
' return "" # TODO: Set Hub repository ID',
|
||||
"",
|
||||
" @property",
|
||||
" def pretrained_model_kwargs(self):",
|
||||
' return {"subfolder": "transformer"}',
|
||||
"",
|
||||
" @property",
|
||||
" def generator(self):",
|
||||
' return torch.Generator("cpu").manual_seed(0)',
|
||||
"",
|
||||
" def get_init_dict(self) -> dict[str, int | list[int]]:",
|
||||
]
|
||||
|
||||
if init_params:
|
||||
lines.append(" # __init__ parameters:")
|
||||
for param in init_params:
|
||||
type_str = f": {param['type']}" if param["type"] else ""
|
||||
default_str = f" = {param['default']}" if param["default"] is not None else ""
|
||||
lines.append(f" # {param['name']}{type_str}{default_str}")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
" return {}",
|
||||
"",
|
||||
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
|
||||
]
|
||||
)
|
||||
|
||||
if forward_params:
|
||||
lines.append(" # forward() parameters:")
|
||||
for param in forward_params:
|
||||
type_str = f": {param['type']}" if param["type"] else ""
|
||||
default_str = f" = {param['default']}" if param["default"] is not None else ""
|
||||
lines.append(f" # {param['name']}{type_str}{default_str}")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
" # TODO: Fill in dummy inputs",
|
||||
" return {}",
|
||||
"",
|
||||
" @property",
|
||||
" def input_shape(self) -> tuple[int, ...]:",
|
||||
" return (1, 1)",
|
||||
"",
|
||||
" @property",
|
||||
" def output_shape(self) -> tuple[int, ...]:",
|
||||
" return (1, 1)",
|
||||
]
|
||||
)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def generate_test_class(model_name: str, config_class: str, tester: str) -> str:
|
||||
tester_short = tester.replace("TesterMixin", "")
|
||||
class_name = f"Test{model_name}{tester_short}"
|
||||
|
||||
lines = [f"class {class_name}({config_class}, {tester}):"]
|
||||
|
||||
if tester == "TorchCompileTesterMixin":
|
||||
lines.extend(
|
||||
[
|
||||
" @property",
|
||||
" def different_shapes_for_compilation(self):",
|
||||
" return [(4, 4), (4, 8), (8, 8)]",
|
||||
"",
|
||||
" def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
|
||||
" # TODO: Implement dynamic input generation",
|
||||
" return {}",
|
||||
]
|
||||
)
|
||||
elif tester == "IPAdapterTesterMixin":
|
||||
lines.extend(
|
||||
[
|
||||
" @property",
|
||||
" def ip_adapter_processor_cls(self):",
|
||||
" return None # TODO: Set processor class",
|
||||
"",
|
||||
" def modify_inputs_for_ip_adapter(self, model, inputs_dict):",
|
||||
" # TODO: Add IP adapter image embeds to inputs",
|
||||
" return inputs_dict",
|
||||
"",
|
||||
" def create_ip_adapter_state_dict(self, model):",
|
||||
" # TODO: Create IP adapter state dict",
|
||||
" return {}",
|
||||
]
|
||||
)
|
||||
elif tester == "SingleFileTesterMixin":
|
||||
lines.extend(
|
||||
[
|
||||
" @property",
|
||||
" def ckpt_path(self):",
|
||||
' return "" # TODO: Set checkpoint path',
|
||||
"",
|
||||
" @property",
|
||||
" def alternate_ckpt_paths(self):",
|
||||
" return []",
|
||||
"",
|
||||
" @property",
|
||||
" def pretrained_model_name_or_path(self):",
|
||||
' return "" # TODO: Set Hub repository ID',
|
||||
]
|
||||
)
|
||||
elif tester == "GGUFTesterMixin":
|
||||
lines.extend(
|
||||
[
|
||||
" @property",
|
||||
" def gguf_filename(self):",
|
||||
' return "" # TODO: Set GGUF filename',
|
||||
"",
|
||||
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
|
||||
" # TODO: Override with larger inputs for quantization tests",
|
||||
" return {}",
|
||||
]
|
||||
)
|
||||
elif tester in ["BitsAndBytesTesterMixin", "QuantoTesterMixin", "TorchAoTesterMixin", "ModelOptTesterMixin"]:
|
||||
lines.extend(
|
||||
[
|
||||
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
|
||||
" # TODO: Override with larger inputs for quantization tests",
|
||||
" return {}",
|
||||
]
|
||||
)
|
||||
elif tester in [
|
||||
"BitsAndBytesCompileTesterMixin",
|
||||
"QuantoCompileTesterMixin",
|
||||
"TorchAoCompileTesterMixin",
|
||||
"ModelOptCompileTesterMixin",
|
||||
]:
|
||||
lines.extend(
|
||||
[
|
||||
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
|
||||
" # TODO: Override with larger inputs for quantization compile tests",
|
||||
" return {}",
|
||||
]
|
||||
)
|
||||
elif tester == "GGUFCompileTesterMixin":
|
||||
lines.extend(
|
||||
[
|
||||
" @property",
|
||||
" def gguf_filename(self):",
|
||||
' return "" # TODO: Set GGUF filename',
|
||||
"",
|
||||
" def get_dummy_inputs(self) -> dict[str, torch.Tensor]:",
|
||||
" # TODO: Override with larger inputs for quantization compile tests",
|
||||
" return {}",
|
||||
]
|
||||
)
|
||||
elif tester in [
|
||||
"PyramidAttentionBroadcastTesterMixin",
|
||||
"FirstBlockCacheTesterMixin",
|
||||
"FasterCacheTesterMixin",
|
||||
]:
|
||||
lines.append(" pass")
|
||||
elif tester == "LoraHotSwappingForModelTesterMixin":
|
||||
lines.extend(
|
||||
[
|
||||
" @property",
|
||||
" def different_shapes_for_compilation(self):",
|
||||
" return [(4, 4), (4, 8), (8, 8)]",
|
||||
"",
|
||||
" def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:",
|
||||
" # TODO: Implement dynamic input generation",
|
||||
" return {}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
lines.append(" pass")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def generate_test_file(model_info: dict, model_filepath: str, include_optional: list[str], imports: set[str]) -> str:
|
||||
model_name = model_info["name"].replace("2DModel", "").replace("3DModel", "").replace("Model", "")
|
||||
testers = determine_testers(model_info, include_optional, imports)
|
||||
tester_imports = sorted(set(testers) - {"LoraHotSwappingForModelTesterMixin"})
|
||||
|
||||
lines = [
|
||||
"# coding=utf-8",
|
||||
"# Copyright 2025 HuggingFace Inc.",
|
||||
"#",
|
||||
'# Licensed under the Apache License, Version 2.0 (the "License");',
|
||||
"# you may not use this file except in compliance with the License.",
|
||||
"# You may obtain a copy of the License at",
|
||||
"#",
|
||||
"# http://www.apache.org/licenses/LICENSE-2.0",
|
||||
"#",
|
||||
"# Unless required by applicable law or agreed to in writing, software",
|
||||
'# distributed under the License is distributed on an "AS IS" BASIS,',
|
||||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.",
|
||||
"# See the License for the specific language governing permissions and",
|
||||
"# limitations under the License.",
|
||||
"",
|
||||
"import torch",
|
||||
"",
|
||||
f"from diffusers import {model_info['name']}",
|
||||
"from diffusers.utils.torch_utils import randn_tensor",
|
||||
"",
|
||||
"from ...testing_utils import enable_full_determinism, torch_device",
|
||||
]
|
||||
|
||||
if "LoraTesterMixin" in testers:
|
||||
lines.append("from ..test_modeling_common import LoraHotSwappingForModelTesterMixin")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
"from ..testing_utils import (",
|
||||
*[f" {tester}," for tester in sorted(tester_imports)],
|
||||
")",
|
||||
"",
|
||||
"",
|
||||
"enable_full_determinism()",
|
||||
"",
|
||||
"",
|
||||
]
|
||||
)
|
||||
|
||||
config_class = f"{model_name}TesterConfig"
|
||||
lines.append(generate_config_class(model_info, model_name))
|
||||
lines.append("")
|
||||
lines.append("")
|
||||
|
||||
for tester in testers:
|
||||
lines.append(generate_test_class(model_name, config_class, tester))
|
||||
lines.append("")
|
||||
lines.append("")
|
||||
|
||||
if "LoraTesterMixin" in testers:
|
||||
lines.append(generate_test_class(model_name, config_class, "LoraHotSwappingForModelTesterMixin"))
|
||||
lines.append("")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines).rstrip() + "\n"
|
||||
|
||||
|
||||
def get_test_output_path(model_filepath: str) -> str:
|
||||
path = Path(model_filepath)
|
||||
model_filename = path.stem
|
||||
|
||||
if "transformers" in path.parts:
|
||||
return f"tests/models/transformers/test_models_{model_filename}.py"
|
||||
elif "unets" in path.parts:
|
||||
return f"tests/models/unets/test_models_{model_filename}.py"
|
||||
elif "autoencoders" in path.parts:
|
||||
return f"tests/models/autoencoders/test_models_{model_filename}.py"
|
||||
else:
|
||||
return f"tests/models/test_models_{model_filename}.py"
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Generate test suite for a diffusers model class")
|
||||
parser.add_argument(
|
||||
"model_filepath",
|
||||
type=str,
|
||||
help="Path to the model file (e.g., src/diffusers/models/transformers/transformer_flux.py)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output", "-o", type=str, default=None, help="Output file path (default: auto-generated based on model path)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--include",
|
||||
"-i",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=[],
|
||||
choices=[
|
||||
"bnb",
|
||||
"quanto",
|
||||
"torchao",
|
||||
"gguf",
|
||||
"modelopt",
|
||||
"bnb_compile",
|
||||
"quanto_compile",
|
||||
"torchao_compile",
|
||||
"gguf_compile",
|
||||
"modelopt_compile",
|
||||
"pab_cache",
|
||||
"fbc_cache",
|
||||
"faster_cache",
|
||||
"single_file",
|
||||
"ip_adapter",
|
||||
"all",
|
||||
],
|
||||
help="Optional testers to include",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class-name",
|
||||
"-c",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specific model class to generate tests for (default: first model class found)",
|
||||
)
|
||||
parser.add_argument("--dry-run", action="store_true", help="Print generated code without writing to file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if not Path(args.model_filepath).exists():
|
||||
print(f"Error: File not found: {args.model_filepath}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
model_classes, imports = analyze_model_file(args.model_filepath)
|
||||
|
||||
if not model_classes:
|
||||
print(f"Error: No model classes found in {args.model_filepath}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if args.class_name:
|
||||
model_info = next((m for m in model_classes if m["name"] == args.class_name), None)
|
||||
if not model_info:
|
||||
available = [m["name"] for m in model_classes]
|
||||
print(f"Error: Class '{args.class_name}' not found. Available: {available}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
else:
|
||||
model_info = model_classes[0]
|
||||
if len(model_classes) > 1:
|
||||
print(f"Multiple model classes found, using: {model_info['name']}", file=sys.stderr)
|
||||
print("Use --class-name to specify a different class", file=sys.stderr)
|
||||
|
||||
include_optional = args.include
|
||||
if "all" in include_optional:
|
||||
include_optional = [flag for _, flag in OPTIONAL_TESTERS]
|
||||
|
||||
generated_code = generate_test_file(model_info, args.model_filepath, include_optional, imports)
|
||||
|
||||
if args.dry_run:
|
||||
print(generated_code)
|
||||
else:
|
||||
output_path = args.output or get_test_output_path(args.model_filepath)
|
||||
output_dir = Path(output_path).parent
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_path, "w") as f:
|
||||
f.write(generated_code)
|
||||
|
||||
print(f"Generated test file: {output_path}")
|
||||
print(f"Model class: {model_info['name']}")
|
||||
print(f"Detected attributes: {list(model_info['attributes'].keys())}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user