Compare commits

..

1 Commits

Author SHA1 Message Date
DN6
76062a74e0 update 2026-03-23 17:16:44 +05:30
32 changed files with 250 additions and 2331 deletions

View File

@@ -446,10 +446,6 @@
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoder_kl_hunyuan_video15
title: AutoencoderKLHunyuanVideo15
- local: api/models/autoencoder_kl_kvae
title: AutoencoderKLKVAE
- local: api/models/autoencoder_kl_kvae_video
title: AutoencoderKLKVAEVideo
- local: api/models/autoencoderkl_audio_ltx_2
title: AutoencoderKLLTX2Audio
- local: api/models/autoencoderkl_ltx_2

View File

@@ -1,32 +0,0 @@
<!-- Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. -->
# AutoencoderKLKVAE
The 2D variational autoencoder (VAE) model with KL loss.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLKVAE
vae = AutoencoderKLKVAE.from_pretrained("kandinskylab/KVAE-2D-1.0", subfolder="diffusers", torch_dtype=torch.bfloat16)
```
## AutoencoderKLKVAE
[[autodoc]] AutoencoderKLKVAE
- decode
- all

View File

@@ -1,33 +0,0 @@
<!-- Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. -->
# AutoencoderKLKVAEVideo
The 3D variational autoencoder (VAE) model with KL loss.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLKVAEVideo
vae = AutoencoderKLKVAEVideo.from_pretrained("kandinskylab/KVAE-3D-1.0", subfolder="diffusers", torch_dtype=torch.float16)
```
## AutoencoderKLKVAEVideo
[[autodoc]] AutoencoderKLKVAEVideo
- decode
- all

View File

@@ -193,8 +193,6 @@ else:
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLHunyuanVideo15",
"AutoencoderKLKVAE",
"AutoencoderKLKVAEVideo",
"AutoencoderKLLTX2Audio",
"AutoencoderKLLTX2Video",
"AutoencoderKLLTXVideo",
@@ -977,8 +975,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLKVAE,
AutoencoderKLKVAEVideo,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import inspect
from dataclasses import dataclass
from typing import Type
@@ -31,7 +32,7 @@ from ..models._modeling_parallel import (
gather_size_by_comm,
)
from ..utils import get_logger
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
from .hooks import HookRegistry, ModelHook
@@ -326,7 +327,7 @@ class PartitionAnythingSharder:
return tensor
@lru_cache_unless_export(maxsize=64)
@functools.lru_cache(maxsize=64)
def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
gather_shapes = []
for i in range(world_size):

View File

@@ -15,7 +15,6 @@
import inspect
import json
import os
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Literal
@@ -45,13 +44,33 @@ from .unet_loader_utils import _maybe_expand_lora_scales
logger = logging.get_logger(__name__)
_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict(
lambda: (lambda model_cls, weights: weights),
{
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
},
)
_SET_ADAPTER_SCALE_FN_MAPPING = {
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
"SD3Transformer2DModel": lambda model_cls, weights: weights,
"FluxTransformer2DModel": lambda model_cls, weights: weights,
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
"ConsisIDTransformer3DModel": lambda model_cls, weights: weights,
"HeliosTransformer3DModel": lambda model_cls, weights: weights,
"MochiTransformer3DModel": lambda model_cls, weights: weights,
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
"SanaTransformer2DModel": lambda model_cls, weights: weights,
"AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
"WanTransformer3DModel": lambda model_cls, weights: weights,
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
"ChronoEditTransformer3DModel": lambda model_cls, weights: weights,
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
"LTX2VideoTransformer3DModel": lambda model_cls, weights: weights,
"LTX2TextConnectors": lambda model_cls, weights: weights,
}
class PeftAdapterMixin:

View File

@@ -40,8 +40,6 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
_import_structure["autoencoders.autoencoder_kl_kvae"] = ["AutoencoderKLKVAE"]
_import_structure["autoencoders.autoencoder_kl_kvae_video"] = ["AutoencoderKLKVAEVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
@@ -163,8 +161,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLKVAE,
AutoencoderKLKVAEVideo,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,

View File

@@ -49,7 +49,7 @@ from ..utils import (
is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from ..utils.torch_utils import maybe_allow_in_graph
from ._modeling_parallel import gather_size_by_comm
@@ -587,7 +587,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
)
@lru_cache_unless_export(maxsize=128)
@functools.lru_cache(maxsize=128)
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
batch_size: int,
seq_len_q: int,

View File

@@ -9,8 +9,6 @@ from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_kvae import AutoencoderKLKVAE
from .autoencoder_kl_kvae_video import AutoencoderKLKVAEVideo
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio

View File

@@ -87,14 +87,7 @@ class HunyuanImageRefinerRMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
t in str(x.dtype) for t in ("float4_", "float8_")
)
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
x.dtype
)
return normalized * self.scale * self.gamma + self.bias
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class HunyuanImageRefinerAttnBlock(nn.Module):

View File

@@ -87,14 +87,7 @@ class HunyuanVideo15RMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
t in str(x.dtype) for t in ("float4_", "float8_")
)
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
x.dtype
)
return normalized * self.scale * self.gamma + self.bias
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class HunyuanVideo15AttnBlock(nn.Module):

View File

@@ -1,802 +0,0 @@
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
class KVAEResnetBlock2D(nn.Module):
r"""
A Resnet block with optional guidance.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to `None`):
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
conv_shortcut (`bool`, *optional*, default to `False`):
If `True` and `in_channels` not equal to `out_channels`, add a 3x3 nn.conv2d layer for skip-connection.
temb_channels (`int`, *optional*, default to `512`): The number of channels in timestep embedding.
zq_ch (`int`, *optional*, default to `None`): Guidance channels for normalization.
add_conv (`bool`, *optional*, default to `False`):
If `True` add conv2d layer for normalization.
normalization (`nn.Module`, *optional*, default to `None`): The normalization layer.
act_fn (`str`, *optional*, default to `"swish"`): The activation function to use.
"""
def __init__(
self,
*,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
temb_channels: int = 512,
zq_ch: Optional[int] = None,
add_conv: bool = False,
act_fn: str = "swish",
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.nonlinearity = get_activation(act_fn)
if zq_ch is None:
self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
else:
self.norm1 = KVAEDecoderSpatialNorm2D(in_channels, zq_channels=zq_ch, add_conv=add_conv)
self.conv1 = nn.Conv2d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=(1, 1), padding_mode="replicate"
)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
if zq_ch is None:
self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True)
else:
self.norm2 = KVAEDecoderSpatialNorm2D(out_channels, zq_channels=zq_ch, add_conv=add_conv)
self.conv2 = nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=(1, 1),
padding_mode="replicate",
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=(1, 1),
padding_mode="replicate",
)
else:
self.nin_shortcut = nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
)
def forward(self, x: torch.Tensor, temb: torch.Tensor, zq: torch.Tensor = None) -> torch.Tensor:
h = x
if zq is None:
h = self.norm1(h)
else:
h = self.norm1(h, zq)
h = self.nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
if zq is None:
h = self.norm2(h)
else:
h = self.norm2(h, zq)
h = self.nonlinearity(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class KVAEPXSDownsample(nn.Module):
def __init__(self, in_channels: int, factor: int = 2):
r"""
A Downsampling module.
Args:
in_channels (`int`): The number of channels in the input.
factor (`int`, *optional*, default to `2`): The downsampling factor.
"""
super().__init__()
self.factor = factor
self.unshuffle = nn.PixelUnshuffle(self.factor)
self.spatial_conv = nn.Conv2d(
in_channels, in_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode="reflect"
)
self.linear = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (bchw)
pxs_interm = self.unshuffle(x)
b, c, h, w = pxs_interm.shape
pxs_interm_view = pxs_interm.view(b, c // self.factor**2, self.factor**2, h, w)
pxs_out = torch.mean(pxs_interm_view, dim=2)
conv_out = self.spatial_conv(x)
# adding it all together
out = conv_out + pxs_out
return self.linear(out)
class KVAEPXSUpsample(nn.Module):
def __init__(self, in_channels: int, factor: int = 2):
r"""
An Upsampling module.
Args:
in_channels (`int`): The number of channels in the input.
factor (`int`, *optional*, default to `2`): The upsampling factor.
"""
super().__init__()
self.factor = factor
self.shuffle = nn.PixelShuffle(self.factor)
self.spatial_conv = nn.Conv2d(
in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode="reflect"
)
self.linear = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
repeated = x.repeat_interleave(self.factor**2, dim=1)
pxs_interm = self.shuffle(repeated)
image_like_ups = F.interpolate(x, scale_factor=2, mode="nearest")
conv_out = self.spatial_conv(image_like_ups)
# adding it all together
out = conv_out + pxs_interm
return self.linear(out)
class KVAEDecoderSpatialNorm2D(nn.Module):
r"""
A 2D normalization module for decoder.
Args:
in_channels (`int`): The number of channels in the input.
zq_channels (`int`): The number of channels in the guidance.
add_conv (`bool`, *optional*, default to `false`):
If `True` add conv2d 3x3 layer for guidance in the beginning.
"""
def __init__(
self,
in_channels: int,
zq_channels: int,
add_conv: bool = False,
):
super().__init__()
self.norm_layer = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
self.add_conv = add_conv
if add_conv:
self.conv = nn.Conv2d(
in_channels=zq_channels,
out_channels=zq_channels,
kernel_size=3,
padding=(1, 1),
padding_mode="replicate",
)
self.conv_y = nn.Conv2d(
in_channels=zq_channels,
out_channels=in_channels,
kernel_size=1,
)
self.conv_b = nn.Conv2d(
in_channels=zq_channels,
out_channels=in_channels,
kernel_size=1,
)
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
f_first = f
f_first_size = f_first.shape[2:]
zq = F.interpolate(zq, size=f_first_size, mode="nearest")
if self.add_conv:
zq = self.conv(zq)
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
class KVAEEncoder2D(nn.Module):
r"""
A 2D encoder module.
Args:
ch (`int`): The base number of channels in multiresolution blocks.
ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`):
The channel multipliers in multiresolution blocks.
num_res_blocks (`int`): The number of Resnet blocks.
in_channels (`int`): The number of channels in the input.
z_channels (`int`): The number of output channels.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels for the last block.
act_fn (`str`, *optional*, default to `"swish"`): The activation function to use.
"""
def __init__(
self,
*,
ch: int,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int,
in_channels: int,
z_channels: int,
double_z: bool = True,
act_fn: str = "swish",
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
if isinstance(num_res_blocks, int):
self.num_res_blocks = [num_res_blocks] * self.num_resolutions
else:
self.num_res_blocks = num_res_blocks
self.nonlinearity = get_activation(act_fn)
self.in_channels = in_channels
self.conv_in = nn.Conv2d(
in_channels=in_channels,
out_channels=self.ch,
kernel_size=3,
padding=(1, 1),
)
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks[i_level]):
block.append(
KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
)
)
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level < self.num_resolutions - 1:
down.downsample = KVAEPXSDownsample(in_channels=block_in) # mb: bad out channels
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
)
self.mid.block_2 = KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
)
# end
self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(
in_channels=block_in,
out_channels=2 * z_channels if double_z else z_channels,
kernel_size=3,
padding=(1, 1),
)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
# timestep embedding
temb = None
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks[i_level]):
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.down[i_level].block[i_block], h, temb)
else:
h = self.down[i_level].block[i_block](h, temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h)
# middle
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb)
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb)
else:
h = self.mid.block_1(h, temb)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = self.nonlinearity(h)
h = self.conv_out(h)
return h
class KVAEDecoder2D(nn.Module):
r"""
A 2D decoder module.
Args:
ch (`int`): The base number of channels in multiresolution blocks.
out_ch (`int`): The number of output channels.
ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`):
The channel multipliers in multiresolution blocks.
num_res_blocks (`int`): The number of Resnet blocks.
in_channels (`int`): The number of channels in the input.
z_channels (`int`): The number of input channels.
give_pre_end (`bool`, *optional*, default to `false`):
If `True` exit the forward pass early and return the penultimate feature map.
zq_ch (`bool`, *optional*, default to `None`): The number of channels in the guidance.
add_conv (`bool`, *optional*, default to `false`): If `True` add conv2d layer for Resnet normalization layer.
act_fn (`str`, *optional*, default to `"swish"`): The activation function to use.
"""
def __init__(
self,
*,
ch: int,
out_ch: int,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int,
in_channels: int,
z_channels: int,
give_pre_end: bool = False,
zq_ch: Optional[int] = None,
add_conv: bool = False,
act_fn: str = "swish",
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.nonlinearity = get_activation(act_fn)
if zq_ch is None:
zq_ch = z_channels
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
self.conv_in = nn.Conv2d(
in_channels=z_channels, out_channels=block_in, kernel_size=3, padding=(1, 1), padding_mode="replicate"
)
# middle
self.mid = nn.Module()
self.mid.block_1 = KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
zq_ch=zq_ch,
add_conv=add_conv,
)
self.mid.block_2 = KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
zq_ch=zq_ch,
add_conv=add_conv,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
zq_ch=zq_ch,
add_conv=add_conv,
)
)
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = KVAEPXSUpsample(in_channels=block_in)
self.up.insert(0, up)
self.norm_out = KVAEDecoderSpatialNorm2D(block_in, zq_ch, add_conv=add_conv) # , gather=gather_norm)
self.conv_out = nn.Conv2d(
in_channels=block_in, out_channels=out_ch, kernel_size=3, padding=(1, 1), padding_mode="replicate"
)
self.gradient_checkpointing = False
def forward(self, z: torch.Tensor) -> torch.Tensor:
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
zq = z
h = self.conv_in(z)
# middle
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, zq)
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, zq)
else:
h = self.mid.block_1(h, temb, zq)
h = self.mid.block_2(h, temb, zq)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.up[i_level].block[i_block], h, temb, zq)
else:
h = self.up[i_level].block[i_block](h, temb, zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq)
h = self.nonlinearity(h)
h = self.conv_out(h)
return h
class AutoencoderKLKVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
channels (int, *optional*, defaults to 128): The base number of channels in multiresolution blocks.
num_enc_blocks (int, *optional*, defaults to 2):
The number of Resnet blocks in encoder multiresolution layers.
num_dec_blocks (int, *optional*, defaults to 2):
The number of Resnet blocks in decoder multiresolution layers.
z_channels (int, *optional*, defaults to 16): Number of channels in the latent space.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels of encoder.
ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`):
The channel multipliers in multiresolution blocks.
sample_size (`int`, *optional*, defaults to `1024`): Sample input size.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
channels: int = 128,
num_enc_blocks: int = 2,
num_dec_blocks: int = 2,
z_channels: int = 16,
double_z: bool = True,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
sample_size: int = 1024,
):
super().__init__()
# pass init params to Encoder
self.encoder = KVAEEncoder2D(
in_channels=in_channels,
ch=channels,
ch_mult=ch_mult,
num_res_blocks=num_enc_blocks,
z_channels=z_channels,
double_z=double_z,
)
# pass init params to Decoder
self.decoder = KVAEDecoder2D(
out_ch=in_channels,
ch=channels,
ch_mult=ch_mult,
num_res_blocks=num_dec_blocks,
in_channels=None,
z_channels=z_channels,
)
self.use_slicing = False
self.use_tiling = False
# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.ch_mult) - 1)))
self.tile_overlap_factor = 0.25
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
return self._tiled_encode(x)
enc = self.encoder(x)
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.Tensor`): Input batch of images.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split the image into 512x512 tiles and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
enc = torch.cat(result_rows, dim=2)
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
# Split z into overlapping 64x64 tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, z.shape[2], overlap_size):
row = []
for j in range(0, z.shape[3], overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
dec = torch.cat(result_rows, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)

View File

@@ -1,954 +0,0 @@
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def nonlinearity(x: torch.Tensor) -> torch.Tensor:
return F.silu(x)
# =============================================================================
# Base layers
# =============================================================================
class KVAESafeConv3d(nn.Conv3d):
r"""
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM.
"""
def forward(self, input: torch.Tensor, write_to: torch.Tensor = None) -> torch.Tensor:
memory_count = input.numel() * input.element_size() / (10**9)
if memory_count > 3:
kernel_size = self.kernel_size[0]
part_num = math.ceil(memory_count / 2)
input_chunks = torch.chunk(input, part_num, dim=2)
if write_to is None:
output = []
for i, chunk in enumerate(input_chunks):
if i == 0 or kernel_size == 1:
z = torch.clone(chunk)
else:
z = torch.cat([z[:, :, -kernel_size + 1 :], chunk], dim=2)
output.append(super().forward(z))
return torch.cat(output, dim=2)
else:
time_offset = 0
for i, chunk in enumerate(input_chunks):
if i == 0 or kernel_size == 1:
z = torch.clone(chunk)
else:
z = torch.cat([z[:, :, -kernel_size + 1 :], chunk], dim=2)
z_time = z.size(2) - (kernel_size - 1)
write_to[:, :, time_offset : time_offset + z_time] = super().forward(z)
time_offset += z_time
return write_to
else:
if write_to is None:
return super().forward(input)
else:
write_to[...] = super().forward(input)
return write_to
class KVAECausalConv3d(nn.Module):
r"""
A 3D causal convolution layer.
"""
def __init__(
self,
chan_in: int,
chan_out: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Tuple[int, int, int] = (1, 1, 1),
dilation: Tuple[int, int, int] = (1, 1, 1),
**kwargs,
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
self.height_pad = height_kernel_size // 2
self.width_pad = width_kernel_size // 2
self.time_pad = time_kernel_size - 1
self.time_kernel_size = time_kernel_size
self.stride = stride
self.conv = KVAESafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, input: torch.Tensor) -> torch.Tensor:
padding_3d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad, self.time_pad, 0)
input_padded = F.pad(input, padding_3d, mode="replicate")
return self.conv(input_padded)
class KVAECachedCausalConv3d(nn.Module):
r"""
A 3D causal convolution layer with caching for temporal processing.
"""
def __init__(
self,
chan_in: int,
chan_out: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Tuple[int, int, int] = (1, 1, 1),
dilation: Tuple[int, int, int] = (1, 1, 1),
**kwargs,
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
self.height_pad = height_kernel_size // 2
self.width_pad = width_kernel_size // 2
self.time_pad = time_kernel_size - 1
self.time_kernel_size = time_kernel_size
self.stride = stride
self.conv = KVAESafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, input: torch.Tensor, cache: Dict) -> torch.Tensor:
t_stride = self.stride[0]
padding_3d = (self.height_pad, self.height_pad, self.width_pad, self.width_pad, 0, 0)
input_parallel = F.pad(input, padding_3d, mode="replicate")
if cache["padding"] is None:
first_frame = input_parallel[:, :, :1]
time_pad_shape = list(first_frame.shape)
time_pad_shape[2] = self.time_pad
padding = first_frame.expand(time_pad_shape)
else:
padding = cache["padding"]
out_size = list(input.shape)
out_size[1] = self.conv.out_channels
if t_stride == 2:
out_size[2] = (input.size(2) + 1) // 2
output = torch.empty(tuple(out_size), dtype=input.dtype, device=input.device)
offset_out = math.ceil(padding.size(2) / t_stride)
offset_in = offset_out * t_stride - padding.size(2)
if offset_out > 0:
padding_poisoned = torch.cat(
[padding, input_parallel[:, :, : offset_in + self.time_kernel_size - t_stride]], dim=2
)
output[:, :, :offset_out] = self.conv(padding_poisoned)
if offset_out < output.size(2):
output[:, :, offset_out:] = self.conv(input_parallel[:, :, offset_in:])
pad_offset = (
offset_in
+ t_stride * math.trunc((input_parallel.size(2) - offset_in - self.time_kernel_size) / t_stride)
+ t_stride
)
cache["padding"] = torch.clone(input_parallel[:, :, pad_offset:])
return output
class KVAECachedGroupNorm(nn.Module):
r"""
GroupNorm with caching support for temporal processing.
"""
def __init__(self, in_channels: int):
super().__init__()
self.norm_layer = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
def forward(self, x: torch.Tensor, cache: Dict = None) -> torch.Tensor:
out = self.norm_layer(x)
if cache is not None and cache.get("mean") is None and cache.get("var") is None:
cache["mean"] = 1
cache["var"] = 1
return out
# =============================================================================
# Cached layers
# =============================================================================
class KVAECachedSpatialNorm3D(nn.Module):
r"""
Spatially conditioned normalization for decoder with caching.
"""
def __init__(
self,
f_channels: int,
zq_channels: int,
add_conv: bool = False,
):
super().__init__()
self.norm_layer = KVAECachedGroupNorm(f_channels)
self.add_conv = add_conv
if add_conv:
self.conv = KVAECachedCausalConv3d(chan_in=zq_channels, chan_out=zq_channels, kernel_size=3)
self.conv_y = KVAESafeConv3d(zq_channels, f_channels, kernel_size=1)
self.conv_b = KVAESafeConv3d(zq_channels, f_channels, kernel_size=1)
def forward(self, f: torch.Tensor, zq: torch.Tensor, cache: Dict) -> torch.Tensor:
if cache["norm"].get("mean") is None and cache["norm"].get("var") is None:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
zq_first = F.interpolate(zq_first, size=f_first_size, mode="nearest")
if zq.size(2) > 1:
zq_rest_splits = torch.split(zq_rest, 32, dim=1)
interpolated_splits = [
F.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits
]
zq_rest = torch.cat(interpolated_splits, dim=1)
zq = torch.cat([zq_first, zq_rest], dim=2)
else:
zq = zq_first
else:
f_size = f.shape[-3:]
zq_splits = torch.split(zq, 32, dim=1)
interpolated_splits = [F.interpolate(split, size=f_size, mode="nearest") for split in zq_splits]
zq = torch.cat(interpolated_splits, dim=1)
if self.add_conv:
zq = self.conv(zq, cache["add_conv"])
norm_f = self.norm_layer(f, cache["norm"])
norm_f = norm_f * self.conv_y(zq)
norm_f = norm_f + self.conv_b(zq)
return norm_f
class KVAECachedResnetBlock3D(nn.Module):
r"""
A 3D ResNet block with caching.
"""
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 0,
zq_ch: Optional[int] = None,
add_conv: bool = False,
gather_norm: bool = False,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
if zq_ch is None:
self.norm1 = KVAECachedGroupNorm(in_channels)
else:
self.norm1 = KVAECachedSpatialNorm3D(in_channels, zq_ch, add_conv=add_conv)
self.conv1 = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=out_channels, kernel_size=3)
if temb_channels > 0:
self.temb_proj = nn.Linear(temb_channels, out_channels)
if zq_ch is None:
self.norm2 = KVAECachedGroupNorm(out_channels)
else:
self.norm2 = KVAECachedSpatialNorm3D(out_channels, zq_ch, add_conv=add_conv)
self.conv2 = KVAECachedCausalConv3d(chan_in=out_channels, chan_out=out_channels, kernel_size=3)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=out_channels, kernel_size=3)
else:
self.nin_shortcut = KVAESafeConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: torch.Tensor, temb: torch.Tensor, layer_cache: Dict, zq: torch.Tensor = None) -> torch.Tensor:
h = x
if zq is None:
# Encoder path - norm takes cache
h = self.norm1(h, cache=layer_cache["norm1"])
else:
# Decoder path - spatial norm takes zq and cache
h = self.norm1(h, zq, cache=layer_cache["norm1"])
h = F.silu(h)
h = self.conv1(h, cache=layer_cache["conv1"])
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
if zq is None:
h = self.norm2(h, cache=layer_cache["norm2"])
else:
h = self.norm2(h, zq, cache=layer_cache["norm2"])
h = F.silu(h)
h = self.conv2(h, cache=layer_cache["conv2"])
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x, cache=layer_cache["conv_shortcut"])
else:
x = self.nin_shortcut(x)
return x + h
class KVAECachedPXSDownsample(nn.Module):
r"""
A 3D downsampling layer using PixelUnshuffle with caching.
"""
def __init__(self, in_channels: int, compress_time: bool, factor: int = 2):
super().__init__()
self.temporal_compress = compress_time
self.factor = factor
self.unshuffle = nn.PixelUnshuffle(self.factor)
self.s_pool = nn.AvgPool3d((1, 2, 2), (1, 2, 2))
self.spatial_conv = KVAESafeConv3d(
in_channels,
in_channels,
kernel_size=(1, 3, 3),
stride=(1, 2, 2),
padding=(0, 1, 1),
padding_mode="reflect",
)
if self.temporal_compress:
self.temporal_conv = KVAECachedCausalConv3d(
in_channels, in_channels, kernel_size=(3, 1, 1), stride=(2, 1, 1), dilation=(1, 1, 1)
)
self.linear = nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1)
def spatial_downsample(self, input: torch.Tensor) -> torch.Tensor:
b, c, t, h, w = input.shape
pxs_input = input.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
# pxs_input = rearrange(input, 'b c t h w -> (b t) c h w')
pxs_interm = self.unshuffle(pxs_input)
b_it, c_it, h_it, w_it = pxs_interm.shape
pxs_interm_view = pxs_interm.view(b_it, c_it // self.factor**2, self.factor**2, h_it, w_it)
pxs_out = torch.mean(pxs_interm_view, dim=2)
pxs_out = pxs_out.view(b, t, -1, h_it, w_it).permute(0, 2, 1, 3, 4)
# pxs_out = rearrange(pxs_out, '(b t) c h w -> b c t h w', t=input.size(2))
conv_out = self.spatial_conv(input)
return conv_out + pxs_out
def temporal_downsample(self, input: torch.Tensor, cache: list) -> torch.Tensor:
b, c, t, h, w = input.shape
permuted = input.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t)
if cache[0]["padding"] is None:
first, rest = permuted[..., :1], permuted[..., 1:]
if rest.size(-1) > 0:
rest_interp = F.avg_pool1d(rest, kernel_size=2, stride=2)
full_interp = torch.cat([first, rest_interp], dim=-1)
else:
full_interp = first
else:
rest = permuted
if rest.size(-1) > 0:
full_interp = F.avg_pool1d(rest, kernel_size=2, stride=2)
t_new = full_interp.size(-1)
full_interp = full_interp.view(b, h, w, c, t_new).permute(0, 3, 4, 1, 2)
conv_out = self.temporal_conv(input, cache[0])
return conv_out + full_interp
def forward(self, x: torch.Tensor, cache: list) -> torch.Tensor:
out = self.spatial_downsample(x)
if self.temporal_compress:
out = self.temporal_downsample(out, cache=cache)
return self.linear(out)
class KVAECachedPXSUpsample(nn.Module):
r"""
A 3D upsampling layer using PixelShuffle with caching.
"""
def __init__(self, in_channels: int, compress_time: bool, factor: int = 2):
super().__init__()
self.temporal_compress = compress_time
self.factor = factor
self.shuffle = nn.PixelShuffle(self.factor)
self.spatial_conv = KVAESafeConv3d(
in_channels,
in_channels,
kernel_size=(1, 3, 3),
stride=(1, 1, 1),
padding=(0, 1, 1),
padding_mode="reflect",
)
if self.temporal_compress:
self.temporal_conv = KVAECachedCausalConv3d(
in_channels, in_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), dilation=(1, 1, 1)
)
self.linear = KVAESafeConv3d(in_channels, in_channels, kernel_size=1, stride=1)
def spatial_upsample(self, input: torch.Tensor) -> torch.Tensor:
b, c, t, h, w = input.shape
input_view = input.permute(0, 2, 1, 3, 4).reshape(b, t * c, h, w)
input_interp = F.interpolate(input_view, scale_factor=2, mode="nearest")
input_interp = input_interp.view(b, t, c, 2 * h, 2 * w).permute(0, 2, 1, 3, 4)
out = self.spatial_conv(input_interp)
return input_interp + out
def temporal_upsample(self, input: torch.Tensor, cache: Dict) -> torch.Tensor:
time_factor = 1.0 + 1.0 * (input.size(2) > 1)
if isinstance(time_factor, torch.Tensor):
time_factor = time_factor.item()
repeated = input.repeat_interleave(int(time_factor), dim=2)
if cache["padding"] is None:
tail = repeated[..., int(time_factor - 1) :, :, :]
else:
tail = repeated
conv_out = self.temporal_conv(tail, cache)
return conv_out + tail
def forward(self, x: torch.Tensor, cache: Dict) -> torch.Tensor:
if self.temporal_compress:
x = self.temporal_upsample(x, cache)
s_out = self.spatial_upsample(x)
to = torch.empty_like(s_out)
lin_out = self.linear(s_out, write_to=to)
return lin_out
# =============================================================================
# Cached Encoder/Decoder
# =============================================================================
class KVAECachedEncoder3D(nn.Module):
r"""
Cached 3D Encoder for KVAE.
"""
def __init__(
self,
ch: int = 128,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int = 2,
dropout: float = 0.0,
in_channels: int = 3,
z_channels: int = 16,
double_z: bool = True,
temporal_compress_times: int = 4,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_channels
self.temporal_compress_level = int(np.log2(temporal_compress_times))
self.conv_in = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=self.ch, kernel_size=3)
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
block_in = ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
KVAECachedResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
dropout=dropout,
temb_channels=self.temb_ch,
)
)
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
if i_level < self.temporal_compress_level:
down.downsample = KVAECachedPXSDownsample(block_in, compress_time=True)
else:
down.downsample = KVAECachedPXSDownsample(block_in, compress_time=False)
self.down.append(down)
self.mid = nn.Module()
self.mid.block_1 = KVAECachedResnetBlock3D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.block_2 = KVAECachedResnetBlock3D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.norm_out = KVAECachedGroupNorm(block_in)
self.conv_out = KVAECachedCausalConv3d(
chan_in=block_in, chan_out=2 * z_channels if double_z else z_channels, kernel_size=3
)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor, cache_dict: Dict) -> torch.Tensor:
temb = None
h = self.conv_in(x, cache=cache_dict["conv_in"])
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(
self.down[i_level].block[i_block], h, temb, cache_dict[i_level][i_block]
)
else:
h = self.down[i_level].block[i_block](h, temb, layer_cache=cache_dict[i_level][i_block])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h, cache=cache_dict[i_level]["down"])
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, cache_dict["mid_1"])
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, cache_dict["mid_2"])
else:
h = self.mid.block_1(h, temb, layer_cache=cache_dict["mid_1"])
h = self.mid.block_2(h, temb, layer_cache=cache_dict["mid_2"])
h = self.norm_out(h, cache=cache_dict["norm_out"])
h = nonlinearity(h)
h = self.conv_out(h, cache=cache_dict["conv_out"])
return h
class KVAECachedDecoder3D(nn.Module):
r"""
Cached 3D Decoder for KVAE.
"""
def __init__(
self,
ch: int = 128,
out_ch: int = 3,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int = 2,
dropout: float = 0.0,
z_channels: int = 16,
zq_ch: Optional[int] = None,
add_conv: bool = False,
temporal_compress_times: int = 4,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.temporal_compress_level = int(np.log2(temporal_compress_times))
if zq_ch is None:
zq_ch = z_channels
block_in = ch * ch_mult[self.num_resolutions - 1]
self.conv_in = KVAECachedCausalConv3d(chan_in=z_channels, chan_out=block_in, kernel_size=3)
self.mid = nn.Module()
self.mid.block_1 = KVAECachedResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
self.mid.block_2 = KVAECachedResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
KVAECachedResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
)
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = KVAECachedPXSUpsample(block_in, compress_time=False)
else:
up.upsample = KVAECachedPXSUpsample(block_in, compress_time=True)
self.up.insert(0, up)
self.norm_out = KVAECachedSpatialNorm3D(block_in, zq_ch, add_conv=add_conv)
self.conv_out = KVAECachedCausalConv3d(chan_in=block_in, chan_out=out_ch, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, z: torch.Tensor, cache_dict: Dict) -> torch.Tensor:
temb = None
zq = z
h = self.conv_in(z, cache_dict["conv_in"])
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, cache_dict["mid_1"], zq)
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, cache_dict["mid_2"], zq)
else:
h = self.mid.block_1(h, temb, layer_cache=cache_dict["mid_1"], zq=zq)
h = self.mid.block_2(h, temb, layer_cache=cache_dict["mid_2"], zq=zq)
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(
self.up[i_level].block[i_block], h, temb, cache_dict[i_level][i_block], zq
)
else:
h = self.up[i_level].block[i_block](h, temb, layer_cache=cache_dict[i_level][i_block], zq=zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h, cache_dict[i_level]["up"])
h = self.norm_out(h, zq, cache_dict["norm_out"])
h = nonlinearity(h)
h = self.conv_out(h, cache_dict["conv_out"])
return h
# =============================================================================
# Main AutoencoderKL class
# =============================================================================
class AutoencoderKLKVAEVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
[KVAE](https://github.com/kandinskylab/kvae-1).
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
all models (such as downloading or saving).
Parameters:
ch (`int`, *optional*, defaults to 128): Base channel count.
ch_mult (`Tuple[int]`, *optional*, defaults to `(1, 2, 4, 8)`): Channel multipliers per level.
num_res_blocks (`int`, *optional*, defaults to 2): Number of residual blocks per level.
in_channels (`int`, *optional*, defaults to 3): Number of input channels.
out_ch (`int`, *optional*, defaults to 3): Number of output channels.
z_channels (`int`, *optional*, defaults to 16): Number of latent channels.
temporal_compress_times (`int`, *optional*, defaults to 4): Temporal compression factor.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["KVAECachedResnetBlock3D"]
@register_to_config
def __init__(
self,
ch: int = 128,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int = 2,
in_channels: int = 3,
out_ch: int = 3,
z_channels: int = 16,
temporal_compress_times: int = 4,
):
super().__init__()
self.encoder = KVAECachedEncoder3D(
ch=ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
in_channels=in_channels,
z_channels=z_channels,
double_z=True,
temporal_compress_times=temporal_compress_times,
)
self.decoder = KVAECachedDecoder3D(
ch=ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
out_ch=out_ch,
z_channels=z_channels,
temporal_compress_times=temporal_compress_times,
)
self.use_slicing = False
self.use_tiling = False
def _make_encoder_cache(self) -> Dict:
"""Create empty cache for cached encoder."""
def make_dict(name, p=None):
if name == "conv":
return {"padding": None}
layer, module = name.split("_")
if layer == "norm":
if module == "enc":
return {"mean": None, "var": None}
else:
return {"norm": make_dict("norm_enc"), "add_conv": make_dict("conv")}
elif layer == "resblock":
return {
"norm1": make_dict(f"norm_{module}"),
"norm2": make_dict(f"norm_{module}"),
"conv1": make_dict("conv"),
"conv2": make_dict("conv"),
"conv_shortcut": make_dict("conv"),
}
elif layer.isdigit():
out_dict = {"down": [make_dict("conv"), make_dict("conv")], "up": make_dict("conv")}
for i in range(p):
out_dict[i] = make_dict(f"resblock_{module}")
return out_dict
cache = {
"conv_in": make_dict("conv"),
"mid_1": make_dict("resblock_enc"),
"mid_2": make_dict("resblock_enc"),
"norm_out": make_dict("norm_enc"),
"conv_out": make_dict("conv"),
}
# Encoder uses num_res_blocks per level
for i in range(len(self.config.ch_mult)):
cache[i] = make_dict(f"{i}_enc", p=self.config.num_res_blocks)
return cache
def _make_decoder_cache(self) -> Dict:
"""Create empty cache for decoder."""
def make_dict(name, p=None):
if name == "conv":
return {"padding": None}
layer, module = name.split("_")
if layer == "norm":
if module == "enc":
return {"mean": None, "var": None}
else:
return {"norm": make_dict("norm_enc"), "add_conv": make_dict("conv")}
elif layer == "resblock":
return {
"norm1": make_dict(f"norm_{module}"),
"norm2": make_dict(f"norm_{module}"),
"conv1": make_dict("conv"),
"conv2": make_dict("conv"),
"conv_shortcut": make_dict("conv"),
}
elif layer.isdigit():
out_dict = {"down": [make_dict("conv"), make_dict("conv")], "up": make_dict("conv")}
for i in range(p):
out_dict[i] = make_dict(f"resblock_{module}")
return out_dict
cache = {
"conv_in": make_dict("conv"),
"mid_1": make_dict("resblock_dec"),
"mid_2": make_dict("resblock_dec"),
"norm_out": make_dict("norm_dec"),
"conv_out": make_dict("conv"),
}
for i in range(len(self.config.ch_mult)):
cache[i] = make_dict(f"{i}_dec", p=self.config.num_res_blocks + 1)
return cache
def enable_slicing(self) -> None:
r"""Enable sliced VAE decoding."""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""Disable sliced VAE decoding."""
self.use_slicing = False
def _encode(self, x: torch.Tensor, seg_len: int = 16) -> torch.Tensor:
# Cached encoder processes by segments
cache = self._make_encoder_cache()
split_list = [seg_len + 1]
n_frames = x.size(2) - (seg_len + 1)
while n_frames > 0:
split_list.append(seg_len)
n_frames -= seg_len
split_list[-1] += n_frames
latent = []
for chunk in torch.split(x, split_list, dim=2):
l = self.encoder(chunk, cache)
sample, _ = torch.chunk(l, 2, dim=1)
latent.append(sample)
return torch.cat(latent, dim=2)
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of videos into latents.
Args:
x (`torch.Tensor`): Input batch of videos with shape (B, C, T, H, W).
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
# For cached encoder, we already did the split in _encode
h_double = torch.cat([h, torch.zeros_like(h)], dim=1)
posterior = DiagonalGaussianDistribution(h_double)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, seg_len: int = 16) -> torch.Tensor:
cache = self._make_decoder_cache()
temporal_compress = self.config.temporal_compress_times
split_list = [seg_len + 1]
n_frames = temporal_compress * (z.size(2) - 1) - seg_len
while n_frames > 0:
split_list.append(seg_len)
n_frames -= seg_len
split_list[-1] += n_frames
split_list = [math.ceil(size / temporal_compress) for size in split_list]
recs = []
for chunk in torch.split(z, split_list, dim=2):
out = self.decoder(chunk, cache)
recs.append(out)
return torch.cat(recs, dim=2)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of videos.
Args:
z (`torch.Tensor`): Input batch of latent vectors with shape (B, C, T, H, W).
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`: Decoded video.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)

View File

@@ -105,14 +105,7 @@ class QwenImageRMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
t in str(x.dtype) for t in ("float4_", "float8_")
)
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
x.dtype
)
return normalized * self.scale * self.gamma + self.bias
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class QwenImageUpsample(nn.Upsample):

View File

@@ -196,14 +196,7 @@ class WanRMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
t in str(x.dtype) for t in ("float4_", "float8_")
)
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
x.dtype
)
return normalized * self.scale * self.gamma + self.bias
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class WanUpsample(nn.Upsample):

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math
from math import prod
from typing import Any
@@ -24,7 +25,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -306,7 +307,7 @@ class QwenEmbedRope(nn.Module):
return vid_freqs, txt_freqs
@lru_cache_unless_export(maxsize=128)
@functools.lru_cache(maxsize=128)
def _compute_video_freqs(
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
) -> torch.Tensor:
@@ -427,7 +428,7 @@ class QwenEmbedLayer3DRope(nn.Module):
return vid_freqs, txt_freqs
@lru_cache_unless_export(maxsize=None)
@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
@@ -449,7 +450,7 @@ class QwenEmbedLayer3DRope(nn.Module):
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
@lru_cache_unless_export(maxsize=None)
@functools.lru_cache(maxsize=None)
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
@@ -933,7 +934,6 @@ class QwenImageTransformer2DModel(
batch_size, image_seq_len = hidden_states.shape[:2]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
joint_attention_mask = joint_attention_mask[:, None, None, :]
block_attention_kwargs["attention_mask"] = joint_attention_mask
for index_block, block in enumerate(self.transformer_blocks):

View File

@@ -788,12 +788,9 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
# Attention mask
if all(seq == max_seqlen for seq in item_seqlens):
attn_mask = None
else:
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(item_seqlens):
attn_mask[i, :seq_len] = 1
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(item_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None
@@ -874,12 +871,9 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
# Attention mask
if all(seq == max_seqlen for seq in unified_seqlens):
attn_mask = None
else:
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_seqlens):
attn_mask[i, :seq_len] = 1
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None

View File

@@ -16,29 +16,22 @@ from typing import Callable
import numpy as np
import torch
import torchvision
import torchvision.transforms
import torchvision.transforms.functional
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
from ...schedulers import UniPCMultistepScheduler
from ...utils import (
is_cosmos_guardrail_available,
is_torch_xla_available,
is_torchvision_available,
logging,
replace_example_docstring,
)
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CosmosPipelineOutput
if is_torchvision_available():
import torchvision.transforms.functional
if is_cosmos_guardrail_available():
from cosmos_guardrail import CosmosSafetyChecker
else:

View File

@@ -521,36 +521,6 @@ class AutoencoderKLHunyuanVideo15(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderKLKVAE(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLKVAEVideo(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLLTX2Audio(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -29,7 +29,6 @@ from numpy.linalg import norm
from packaging import version
from .constants import DIFFUSERS_REQUEST_TIMEOUT
from .deprecation_utils import deprecate
from .import_utils import (
BACKENDS_MAPPING,
is_accelerate_available,
@@ -68,11 +67,9 @@ else:
global_rng = random.Random()
logger = get_logger(__name__)
deprecate(
"diffusers.utils.testing_utils",
"1.0.0",
"diffusers.utils.testing_utils is deprecated and will be removed in a future version. "
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. ",
logger.warning(
"diffusers.utils.testing_utils' is deprecated and will be removed in a future version. "
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
)
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version

View File

@@ -19,16 +19,11 @@ from __future__ import annotations
import functools
import os
from typing import Callable, ParamSpec, TypeVar
from . import logging
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
T = TypeVar("T")
P = ParamSpec("P")
if is_torch_available():
import torch
from torch.fft import fftn, fftshift, ifftn, ifftshift
@@ -338,23 +333,5 @@ def disable_full_determinism():
torch.use_deterministic_algorithms(False)
@functools.wraps(functools.lru_cache)
def lru_cache_unless_export(maxsize=128, typed=False):
def outer_wrapper(fn: Callable[P, T]):
cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
if is_torch_version("<", "2.7.0"):
return cached
@functools.wraps(fn)
def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
if torch.compiler.is_exporting():
return fn(*args, **kwargs)
return cached(*args, **kwargs)
return inner_wrapper
return outer_wrapper
if is_torch_available():
torch_device = get_device()

View File

@@ -1,73 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderKLKVAE
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLKVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLKVAE
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_kvae_config(self):
return {
"in_channels": 3,
"channels": 32,
"num_enc_blocks": 1,
"num_dec_blocks": 1,
"z_channels": 4,
"double_z": True,
"ch_mult": (1, 2),
"sample_size": 32,
}
@property
def dummy_input(self):
batch_size = 2
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_kvae_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"KVAEEncoder2D",
"KVAEDecoder2D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -1,118 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderKLKVAEVideo
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLKVAEVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLKVAEVideo
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_kvae_video_config(self):
return {
"ch": 32,
"ch_mult": (1, 2),
"num_res_blocks": 1,
"in_channels": 3,
"out_ch": 3,
"z_channels": 4,
"temporal_compress_times": 2,
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 3 # satisfies (T-1) % temporal_compress_times == 0 with temporal_compress_times=2
num_channels = 3
sizes = (16, 16)
video = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
return {"sample": video}
@property
def input_shape(self):
return (3, 3, 16, 16)
@property
def output_shape(self):
return (3, 3, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_kvae_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"KVAECachedEncoder3D",
"KVAECachedDecoder3D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
@unittest.skip(
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
)
def test_model_parallelism(self):
pass
@unittest.skip(
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
)
def test_sharded_checkpoints_device_map(self):
pass
def _run_nondeterministic(self, fn):
# reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
# temporarily relax the requirement for training tests that do backward passes.
import torch
torch.use_deterministic_algorithms(False)
try:
fn()
finally:
torch.use_deterministic_algorithms(True)
def test_training(self):
self._run_nondeterministic(super().test_training)
def test_ema_training(self):
self._run_nondeterministic(super().test_ema_training)
@unittest.skip(
"Gradient checkpointing recomputes the forward pass, but the model uses a stateful cache_dict "
"that is mutated during the first forward. On recomputation the cache is already populated, "
"causing a different execution path and numerically different gradients. "
"GC still reduces peak memory usage; gradient correctness in the presence of GC is a known limitation."
)
def test_effective_gradient_checkpointing(self):
pass
def test_layerwise_casting_training(self):
self._run_nondeterministic(super().test_layerwise_casting_training)

View File

@@ -481,8 +481,6 @@ class LoraHotSwappingForModelTesterMixin:
# ensure that enable_lora_hotswap is called before loading the first adapter
import logging
from diffusers.utils import logging as diffusers_logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
@@ -490,31 +488,21 @@ class LoraHotSwappingForModelTesterMixin:
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
diffusers_logging.enable_propagation()
try:
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in record.message for record in caplog.records)
finally:
diffusers_logging.disable_propagation()
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in record.message for record in caplog.records)
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
# check possibility to ignore the error/warning
import logging
from diffusers.utils import logging as diffusers_logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
diffusers_logging.enable_propagation()
try:
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert len(caplog.records) == 0
finally:
diffusers_logging.disable_propagation()
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert len(caplog.records) == 0
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
@@ -530,26 +518,20 @@ class LoraHotSwappingForModelTesterMixin:
# check the error and log
import logging
from diffusers.utils import logging as diffusers_logging
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
target_modules0 = ["to_q"]
target_modules1 = ["to_q", "to_k"]
diffusers_logging.enable_propagation()
try:
with pytest.raises(RuntimeError): # peft raises RuntimeError
with caplog.at_level(logging.ERROR):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=8,
rank1=8,
target_modules0=target_modules0,
target_modules1=target_modules1,
)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
finally:
diffusers_logging.disable_propagation()
with pytest.raises(RuntimeError): # peft raises RuntimeError
with caplog.at_level(logging.ERROR):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=8,
rank1=8,
target_modules0=target_modules0,
target_modules1=target_modules1,
)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
@require_torch_version_greater("2.7.1")

View File

@@ -22,7 +22,6 @@ import torch.distributed as dist
import torch.multiprocessing as mp
from diffusers.models._modeling_parallel import ContextParallelConfig
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
from ...testing_utils import (
is_context_parallel,
@@ -161,21 +160,16 @@ def _custom_mesh_worker(
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
def test_context_parallel_inference(self, cp_type):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
if cp_type == "ring_degree":
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
if active_backend == AttentionBackendName.NATIVE:
pytest.skip("Ring attention is not supported with the native attention backend.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
inputs_dict = self.get_dummy_inputs()
# Move all tensors to CPU for multiprocessing
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
@@ -200,10 +194,6 @@ class ContextParallelTesterMixin:
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_batch_inputs(self, cp_type):
self.test_context_parallel_inference(cp_type, batch_size=2)
@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
[
@@ -219,11 +209,6 @@ class ContextParallelTesterMixin:
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
if cp_type == "ring_degree":
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
if active_backend == AttentionBackendName.NATIVE:
pytest.skip("Ring attention is not supported with the native attention backend.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}

View File

@@ -150,7 +150,8 @@ class FluxTransformerTesterConfig(BaseModelTesterConfig):
"axes_dims_rope": [4, 4, 8],
}
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
height = width = 4
num_latent_channels = 4
num_image_channels = 3

View File

@@ -90,7 +90,8 @@ class Flux2TransformerTesterConfig(BaseModelTesterConfig):
"axes_dims_rope": [4, 4, 4, 4],
}
def get_dummy_inputs(self, height: int = 4, width: int = 4, batch_size: int = 1) -> dict[str, torch.Tensor]:
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32

View File

@@ -14,7 +14,6 @@
import warnings
import pytest
import torch
from diffusers import QwenImageTransformer2DModel
@@ -78,7 +77,8 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
"axes_dims_rope": (8, 4, 4),
}
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
height = width = 4
sequence_length = 8
@@ -106,10 +106,9 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
@pytest.mark.parametrize("batch_size", [1, 2])
def test_infers_text_seq_len_from_mask(self, batch_size):
def test_infers_text_seq_len_from_mask(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs(batch_size=batch_size)
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
@@ -123,7 +122,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
assert isinstance(per_sample_len, torch.Tensor)
assert int(per_sample_len.max().item()) == 2
assert normalized_mask.dtype == torch.bool
assert normalized_mask.sum().item() == 2 * batch_size
assert normalized_mask.sum().item() == 2
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
inputs["encoder_hidden_states_mask"] = normalized_mask
@@ -140,7 +139,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
)
assert int(per_sample_len2.max().item()) == 8
assert normalized_mask2.sum().item() == 5 * batch_size
assert normalized_mask2.sum().item() == 5
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], None
@@ -150,10 +149,9 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
assert per_sample_len_none is None
assert normalized_mask_none is None
@pytest.mark.parametrize("batch_size", [1, 2])
def test_non_contiguous_attention_mask(self, batch_size):
def test_non_contiguous_attention_mask(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs(batch_size=batch_size)
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
@@ -286,14 +284,6 @@ class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterM
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for QwenImage Transformer."""
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
def test_hotswapping_compiled_model_linear(self):
super().test_hotswapping_compiled_model_linear()
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
def test_hotswapping_compiled_model_both_linear_and_other(self):
super().test_hotswapping_compiled_model_both_linear_and_other()
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]

View File

@@ -1,3 +1,4 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,57 +13,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import SanaTransformer2DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SanaTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.7, 0.7, 0.9]
class SanaTransformer2DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return SanaTransformer2DModel
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = 32
width = 32
embedding_dim = 8
sequence_length = 8
def output_shape(self) -> tuple[int, ...]:
return (4, 32, 32)
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def input_shape(self) -> tuple[int, ...]:
return (4, 32, 32)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def uses_custom_attn_processor(self) -> bool:
return True
@property
def model_split_percents(self) -> list:
return [0.7, 0.7, 0.9]
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 4,
"out_channels": 4,
@@ -75,9 +77,53 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
"caption_channels": 8,
"sample_size": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 2
num_channels = 4
height = 32
width = 32
embedding_dim = 8
sequence_length = 8
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,)).to(torch_device),
}
class TestSanaTransformer2D(SanaTransformer2DTesterConfig, ModelTesterMixin):
"""Core model tests for Sana Transformer 2D."""
class TestSanaTransformer2DMemory(SanaTransformer2DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Sana Transformer 2D."""
class TestSanaTransformer2DTraining(SanaTransformer2DTesterConfig, TrainingTesterMixin):
"""Training tests for Sana Transformer 2D."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"SanaTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestSanaTransformer2DAttention(SanaTransformer2DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Sana Transformer 2D."""
class TestSanaTransformer2DCompile(SanaTransformer2DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Sana Transformer 2D."""
class TestSanaTransformer2DBitsAndBytes(SanaTransformer2DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Sana Transformer 2D."""
class TestSanaTransformer2DTorchAo(SanaTransformer2DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Sana Transformer 2D."""

View File

@@ -1,3 +1,4 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,57 +13,54 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import SanaVideoTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class SanaVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = SanaVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class SanaVideoTransformer3DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return SanaVideoTransformer3DModel
@property
def dummy_input(self):
batch_size = 1
num_channels = 16
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
def output_shape(self) -> tuple[int, ...]:
return (16, 2, 16, 16)
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
@property
def input_shape(self) -> tuple[int, ...]:
return (16, 2, 16, 16)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def uses_custom_attn_processor(self) -> bool:
return True
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | float | list[int] | tuple | str | bool]:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (16, 2, 16, 16)
@property
def output_shape(self):
return (16, 2, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 2,
@@ -82,16 +80,56 @@ class SanaVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
"qk_norm": "rms_norm_across_heads",
"rope_max_seq_len": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 16
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,)).to(torch_device),
}
class TestSanaVideoTransformer3D(SanaVideoTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Sana Video Transformer 3D."""
class TestSanaVideoTransformer3DMemory(SanaVideoTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Sana Video Transformer 3D."""
class TestSanaVideoTransformer3DTraining(SanaVideoTransformer3DTesterConfig, TrainingTesterMixin):
"""Training tests for Sana Video Transformer 3D."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"SanaVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class SanaVideoTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = SanaVideoTransformer3DModel
class TestSanaVideoTransformer3DAttention(SanaVideoTransformer3DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Sana Video Transformer 3D."""
def prepare_init_args_and_inputs_for_common(self):
return SanaVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class TestSanaVideoTransformer3DCompile(SanaVideoTransformer3DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Sana Video Transformer 3D."""
class TestSanaVideoTransformer3DBitsAndBytes(SanaVideoTransformer3DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Sana Video Transformer 3D."""
class TestSanaVideoTransformer3DTorchAo(SanaVideoTransformer3DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Sana Video Transformer 3D."""

View File

@@ -13,10 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import unittest
import warnings
import pytest
@@ -184,25 +182,6 @@ class DeprecateTester(unittest.TestCase):
assert str(warning.warning) == "This message is better!!!"
assert "diffusers/tests/others/test_utils.py" in warning.filename
def test_deprecate_testing_utils_module(self):
import diffusers.utils.testing_utils
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always")
importlib.reload(diffusers.utils.testing_utils)
deprecation_warnings = [w for w in caught_warnings if issubclass(w.category, FutureWarning)]
assert len(deprecation_warnings) >= 1, "Expected at least one FutureWarning from diffusers.utils.testing_utils"
messages = [str(w.message) for w in deprecation_warnings]
assert any("diffusers.utils.testing_utils" in msg for msg in messages), (
f"Expected a deprecation warning mentioning 'diffusers.utils.testing_utils', got: {messages}"
)
assert any(
"diffusers.utils.testing_utils is deprecated and will be removed in a future version." in msg
for msg in messages
), f"Expected deprecation message substring not found, got: {messages}"
# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
class ExpectationsTester(unittest.TestCase):

View File

@@ -43,7 +43,7 @@ def filter_pipelines(usage_dict, usage_cutoff=10000):
def fetch_pipeline_objects():
models = api.list_models(filter="diffusers")
models = api.list_models(library="diffusers")
downloads = defaultdict(int)
for model in models: