Compare commits


16 Commits

Author SHA1 Message Date
Sayak Paul
17257b5d68 Merge branch 'main' into unet-model-tests-refactor 2026-03-23 15:27:34 +05:30
Sayak Paul
ecfd3b4f99 Merge branch 'main' into unet-model-tests-refactor 2026-02-16 16:37:07 +05:30
sayakpaul
5f8303fe3c remove test suites that now pass. 2026-02-16 16:36:06 +05:30
sayakpaul
99de4ceab8 [tests] refactor test_models_unet_spatiotemporal.py to use modular testing mixins
Refactored the spatiotemporal UNet test file to follow the modern modular testing
pattern with BaseModelTesterConfig and focused test classes:

- UNetSpatioTemporalTesterConfig: Base configuration with model setup
- TestUNetSpatioTemporal: Core model tests (ModelTesterMixin, UNetTesterMixin)
- TestUNetSpatioTemporalAttention: Attention-related tests (AttentionTesterMixin)
- TestUNetSpatioTemporalMemory: Memory/offloading tests (MemoryTesterMixin)
- TestUNetSpatioTemporalTraining: Training tests (TrainingTesterMixin)
- TestUNetSpatioTemporalLoRA: LoRA adapter tests (LoraTesterMixin)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 16:13:42 +05:30
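As a rough illustration of the layout this commit describes: one config class pins the model class, its init kwargs, and dummy inputs, and each mixin contributes the tests for a single concern. The sketch below is not the PR's code; `BaseModelTesterConfig` and `ModelTesterMixin` are names taken from the commit message, while the stand-in model, field names, and test bodies are assumptions.

```python
# Minimal sketch of the modular testing pattern, assuming pytest-style classes.
# BaseModelTesterConfig / ModelTesterMixin are names from the commit message;
# everything else here (TinyUNetStandIn, method names) is illustrative only.
import torch
import torch.nn as nn


class TinyUNetStandIn(nn.Module):
    """Toy stand-in for the real UNet so the sketch runs on its own."""

    def __init__(self, in_channels=4, out_channels=4):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, sample, timestep=None):
        return self.conv(sample)


class BaseModelTesterConfig:
    """One place that pins the model class, its init kwargs, and dummy inputs."""

    model_class = TinyUNetStandIn

    def get_init_dict(self):
        return {"in_channels": 4, "out_channels": 4}

    def get_dummy_inputs(self):
        return {"sample": torch.randn(1, 4, 16, 16), "timestep": torch.tensor([10])}


class ModelTesterMixin:
    """Each mixin contributes the tests for exactly one concern."""

    def test_output_shape(self):
        model = self.model_class(**self.get_init_dict())
        inputs = self.get_dummy_inputs()
        with torch.no_grad():
            out = model(**inputs)
        assert out.shape == inputs["sample"].shape


class TestUNetSpatioTemporal(BaseModelTesterConfig, ModelTesterMixin):
    """Core tests; sibling classes would mix in attention, memory, training, LoRA."""
```

Splitting by concern this way means a slow group (say, memory/offloading) can be selected or skipped with `pytest -k` without touching the core tests.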
sayakpaul
c6e6992cdd [tests] refactor test_models_unet_controlnetxs.py to use modular testing mixins
Refactor UNetControlNetXSModel tests to follow the modern testing
pattern with separate classes for core, memory, training, and LoRA.
Specialized tests (from_unet, freeze_unet, forward_no_control,
time_embedding_mixing) remain in the core test class.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 16:11:00 +05:30
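Continuing the sketch above: the model-specific tests this commit names would sit directly on the core test class rather than in a shared mixin. `test_forward_no_control` is a name from the commit message; the body is assumed.

```python
class TestUNetControlNetXS(BaseModelTesterConfig, ModelTesterMixin):
    # Specialized behavior stays with the core tests instead of a shared mixin.
    # The test name comes from the commit message; this body is illustrative.
    def test_forward_no_control(self):
        model = self.model_class(**self.get_init_dict())
        with torch.no_grad():
            out = model(**self.get_dummy_inputs())
        assert torch.isfinite(out).all()
```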
sayakpaul
ecbaed793d [tests] refactor test_models_unet_3d_condition.py to use modular testing mixins
Refactor UNet3DConditionModel tests to follow the modern testing pattern
with separate classes for core, attention, memory, training, and LoRA.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 16:09:42 +05:30
sayakpaul
0411da7739 [tests] refactor test_models_unet_2d.py to use modular testing mixins
Refactor UNet2D model tests (standard, LDM, NCSN++) to follow the
modern testing pattern. Each variant gets its own config class and
dedicated test classes organized by concern (core, memory, training,
LoRA, hub loading).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 16:08:35 +05:30
sayakpaul
ffb254a273 [tests] refactor test_models_unet_1d.py to use modular testing mixins
Refactor UNet1D model tests to follow the modern testing pattern using
BaseModelTesterConfig and focused mixin classes (ModelTesterMixin,
MemoryTesterMixin, TrainingTesterMixin, LoraTesterMixin).

Both UNet1D standard and RL variants now have separate config classes
and dedicated test classes organized by concern (core, memory, training,
LoRA, hub loading).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 16:05:32 +05:30
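One way to read "separate config classes" per variant: the variant config overrides only what differs, and each variant's test classes re-mix the same concern mixins. Illustrative only, reusing the sketch above; the class names follow the commit message, the overridden values are assumptions.

```python
class UNet1DTesterConfig(BaseModelTesterConfig):
    pass  # standard-variant init dict and dummy inputs would live here


class UNet1DRLTesterConfig(UNet1DTesterConfig):
    def get_init_dict(self):
        # The RL variant overrides only the knobs that differ (assumed values).
        return {"in_channels": 4, "out_channels": 4}


class TestUNet1D(UNet1DTesterConfig, ModelTesterMixin):
    pass


class TestUNet1DRL(UNet1DRLTesterConfig, ModelTesterMixin):
    pass
```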
sayakpaul
ea08148bbd recompile limit 2026-02-16 15:41:57 +05:30
sayakpaul
3a610814a3 Merge branch 'main' into unet-model-tests-refactor 2026-02-16 15:41:00 +05:30
sayakpaul
ca4a7b0649 up 2026-02-16 15:40:24 +05:30
sayakpaul
3371560f1d Revert "fix"
This reverts commit 46d44b73d8.
2026-02-16 13:34:24 +05:30
sayakpaul
46d44b73d8 fix 2026-02-16 13:30:54 +05:30
sayakpaul
2b67fb65ef up 2026-02-16 13:10:04 +05:30
sayakpaul
0e42a3ff93 fix tests 2026-02-16 11:59:33 +05:30
sayakpaul
14439ab793 refactor unet2d condition model tests. 2026-02-16 10:08:41 +05:30
31 changed files with 1027 additions and 3610 deletions

View File

@@ -446,10 +446,6 @@
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoder_kl_hunyuan_video15
title: AutoencoderKLHunyuanVideo15
- local: api/models/autoencoder_kl_kvae
title: AutoencoderKLKVAE
- local: api/models/autoencoder_kl_kvae_video
title: AutoencoderKLKVAEVideo
- local: api/models/autoencoderkl_audio_ltx_2
title: AutoencoderKLLTX2Audio
- local: api/models/autoencoderkl_ltx_2

View File

@@ -1,32 +0,0 @@
<!-- Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. -->
# AutoencoderKLKVAE
The 2D variational autoencoder (VAE) model with KL loss.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLKVAE
vae = AutoencoderKLKVAE.from_pretrained("kandinskylab/KVAE-2D-1.0", subfolder="diffusers", torch_dtype=torch.bfloat16)
```
## AutoencoderKLKVAE
[[autodoc]] AutoencoderKLKVAE
- decode
- all

View File

@@ -1,33 +0,0 @@
<!-- Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. -->
# AutoencoderKLKVAEVideo
The 3D variational autoencoder (VAE) model with KL loss.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLKVAEVideo
vae = AutoencoderKLKVAEVideo.from_pretrained("kandinskylab/KVAE-3D-1.0", subfolder="diffusers", torch_dtype=torch.float16)
```
## AutoencoderKLKVAEVideo
[[autodoc]] AutoencoderKLKVAEVideo
- decode
- all

View File

@@ -193,8 +193,6 @@ else:
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLHunyuanVideo15",
"AutoencoderKLKVAE",
"AutoencoderKLKVAEVideo",
"AutoencoderKLLTX2Audio",
"AutoencoderKLLTX2Video",
"AutoencoderKLLTXVideo",
@@ -977,8 +975,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLKVAE,
AutoencoderKLKVAEVideo,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import inspect
from dataclasses import dataclass
from typing import Type
@@ -31,7 +32,7 @@ from ..models._modeling_parallel import (
gather_size_by_comm,
)
from ..utils import get_logger
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
from .hooks import HookRegistry, ModelHook
@@ -326,7 +327,7 @@ class PartitionAnythingSharder:
return tensor
@lru_cache_unless_export(maxsize=64)
@functools.lru_cache(maxsize=64)
def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
gather_shapes = []
for i in range(world_size):

View File

@@ -40,8 +40,6 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
_import_structure["autoencoders.autoencoder_kl_kvae"] = ["AutoencoderKLKVAE"]
_import_structure["autoencoders.autoencoder_kl_kvae_video"] = ["AutoencoderKLKVAEVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
@@ -163,8 +161,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLKVAE,
AutoencoderKLKVAEVideo,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,

View File

@@ -49,7 +49,7 @@ from ..utils import (
is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from ..utils.torch_utils import maybe_allow_in_graph
from ._modeling_parallel import gather_size_by_comm
@@ -587,7 +587,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
)
@lru_cache_unless_export(maxsize=128)
@functools.lru_cache(maxsize=128)
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
batch_size: int,
seq_len_q: int,

View File

@@ -9,8 +9,6 @@ from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_kvae import AutoencoderKLKVAE
from .autoencoder_kl_kvae_video import AutoencoderKLKVAEVideo
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio

View File

@@ -1,802 +0,0 @@
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
class KVAEResnetBlock2D(nn.Module):
r"""
A Resnet block with optional guidance.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, defaults to `None`):
The number of output channels for the first conv2d layer. If `None`, same as `in_channels`.
conv_shortcut (`bool`, *optional*, defaults to `False`):
If `True` and `in_channels` does not equal `out_channels`, adds a 3x3 nn.Conv2d layer for the skip connection.
temb_channels (`int`, *optional*, defaults to `512`): The number of channels in the timestep embedding.
zq_ch (`int`, *optional*, defaults to `None`): Guidance channels for normalization.
add_conv (`bool`, *optional*, defaults to `False`):
If `True`, adds a conv2d layer to the normalization.
act_fn (`str`, *optional*, defaults to `"swish"`): The activation function to use.
"""
def __init__(
self,
*,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
temb_channels: int = 512,
zq_ch: Optional[int] = None,
add_conv: bool = False,
act_fn: str = "swish",
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.nonlinearity = get_activation(act_fn)
if zq_ch is None:
self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
else:
self.norm1 = KVAEDecoderSpatialNorm2D(in_channels, zq_channels=zq_ch, add_conv=add_conv)
self.conv1 = nn.Conv2d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=(1, 1), padding_mode="replicate"
)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
if zq_ch is None:
self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True)
else:
self.norm2 = KVAEDecoderSpatialNorm2D(out_channels, zq_channels=zq_ch, add_conv=add_conv)
self.conv2 = nn.Conv2d(
in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
padding=(1, 1),
padding_mode="replicate",
)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
padding=(1, 1),
padding_mode="replicate",
)
else:
self.nin_shortcut = nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
)
def forward(self, x: torch.Tensor, temb: torch.Tensor, zq: torch.Tensor = None) -> torch.Tensor:
h = x
if zq is None:
h = self.norm1(h)
else:
h = self.norm1(h, zq)
h = self.nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]  # 2D feature maps need two trailing dims
if zq is None:
h = self.norm2(h)
else:
h = self.norm2(h, zq)
h = self.nonlinearity(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class KVAEPXSDownsample(nn.Module):
def __init__(self, in_channels: int, factor: int = 2):
r"""
A Downsampling module.
Args:
in_channels (`int`): The number of channels in the input.
factor (`int`, *optional*, defaults to `2`): The downsampling factor.
"""
super().__init__()
self.factor = factor
self.unshuffle = nn.PixelUnshuffle(self.factor)
self.spatial_conv = nn.Conv2d(
in_channels, in_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode="reflect"
)
self.linear = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (bchw)
pxs_interm = self.unshuffle(x)
b, c, h, w = pxs_interm.shape
pxs_interm_view = pxs_interm.view(b, c // self.factor**2, self.factor**2, h, w)
pxs_out = torch.mean(pxs_interm_view, dim=2)
conv_out = self.spatial_conv(x)
# adding it all together
out = conv_out + pxs_out
return self.linear(out)
class KVAEPXSUpsample(nn.Module):
def __init__(self, in_channels: int, factor: int = 2):
r"""
An Upsampling module.
Args:
in_channels (`int`): The number of channels in the input.
factor (`int`, *optional*, defaults to `2`): The upsampling factor.
"""
super().__init__()
self.factor = factor
self.shuffle = nn.PixelShuffle(self.factor)
self.spatial_conv = nn.Conv2d(
in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode="reflect"
)
self.linear = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
repeated = x.repeat_interleave(self.factor**2, dim=1)
pxs_interm = self.shuffle(repeated)
image_like_ups = F.interpolate(x, scale_factor=2, mode="nearest")
conv_out = self.spatial_conv(image_like_ups)
# adding it all together
out = conv_out + pxs_interm
return self.linear(out)
class KVAEDecoderSpatialNorm2D(nn.Module):
r"""
A 2D normalization module for decoder.
Args:
in_channels (`int`): The number of channels in the input.
zq_channels (`int`): The number of channels in the guidance.
add_conv (`bool`, *optional*, defaults to `False`):
If `True`, applies a 3x3 conv2d layer to the guidance first.
"""
def __init__(
self,
in_channels: int,
zq_channels: int,
add_conv: bool = False,
):
super().__init__()
self.norm_layer = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
self.add_conv = add_conv
if add_conv:
self.conv = nn.Conv2d(
in_channels=zq_channels,
out_channels=zq_channels,
kernel_size=3,
padding=(1, 1),
padding_mode="replicate",
)
self.conv_y = nn.Conv2d(
in_channels=zq_channels,
out_channels=in_channels,
kernel_size=1,
)
self.conv_b = nn.Conv2d(
in_channels=zq_channels,
out_channels=in_channels,
kernel_size=1,
)
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
f_first = f
f_first_size = f_first.shape[2:]
zq = F.interpolate(zq, size=f_first_size, mode="nearest")
if self.add_conv:
zq = self.conv(zq)
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
class KVAEEncoder2D(nn.Module):
r"""
A 2D encoder module.
Args:
ch (`int`): The base number of channels in multiresolution blocks.
ch_mult (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 4, 8)`):
The channel multipliers in multiresolution blocks.
num_res_blocks (`int` or `List[int]`): The number of Resnet blocks per resolution.
in_channels (`int`): The number of channels in the input.
z_channels (`int`): The number of output channels.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels for the last block.
act_fn (`str`, *optional*, defaults to `"swish"`): The activation function to use.
"""
def __init__(
self,
*,
ch: int,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int,
in_channels: int,
z_channels: int,
double_z: bool = True,
act_fn: str = "swish",
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
if isinstance(num_res_blocks, int):
self.num_res_blocks = [num_res_blocks] * self.num_resolutions
else:
self.num_res_blocks = num_res_blocks
self.nonlinearity = get_activation(act_fn)
self.in_channels = in_channels
self.conv_in = nn.Conv2d(
in_channels=in_channels,
out_channels=self.ch,
kernel_size=3,
padding=(1, 1),
)
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks[i_level]):
block.append(
KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
)
)
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level < self.num_resolutions - 1:
down.downsample = KVAEPXSDownsample(in_channels=block_in)  # possible issue: output channel count may be wrong here
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
)
self.mid.block_2 = KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
)
# end
self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(
in_channels=block_in,
out_channels=2 * z_channels if double_z else z_channels,
kernel_size=3,
padding=(1, 1),
)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
# timestep embedding
temb = None
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks[i_level]):
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.down[i_level].block[i_block], h, temb)
else:
h = self.down[i_level].block[i_block](h, temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h)
# middle
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb)
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb)
else:
h = self.mid.block_1(h, temb)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = self.nonlinearity(h)
h = self.conv_out(h)
return h
class KVAEDecoder2D(nn.Module):
r"""
A 2D decoder module.
Args:
ch (`int`): The base number of channels in multiresolution blocks.
out_ch (`int`): The number of output channels.
ch_mult (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 4, 8)`):
The channel multipliers in multiresolution blocks.
num_res_blocks (`int`): The number of Resnet blocks.
in_channels (`int`): The number of channels in the input.
z_channels (`int`): The number of channels in the input latent.
give_pre_end (`bool`, *optional*, defaults to `False`):
If `True`, exits the forward pass early and returns the penultimate feature map.
zq_ch (`int`, *optional*, defaults to `None`): The number of channels in the guidance.
add_conv (`bool`, *optional*, defaults to `False`): If `True`, adds a conv2d layer to the Resnet normalization layers.
act_fn (`str`, *optional*, defaults to `"swish"`): The activation function to use.
"""
def __init__(
self,
*,
ch: int,
out_ch: int,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int,
in_channels: int,
z_channels: int,
give_pre_end: bool = False,
zq_ch: Optional[int] = None,
add_conv: bool = False,
act_fn: str = "swish",
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.nonlinearity = get_activation(act_fn)
if zq_ch is None:
zq_ch = z_channels
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
self.conv_in = nn.Conv2d(
in_channels=z_channels, out_channels=block_in, kernel_size=3, padding=(1, 1), padding_mode="replicate"
)
# middle
self.mid = nn.Module()
self.mid.block_1 = KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
zq_ch=zq_ch,
add_conv=add_conv,
)
self.mid.block_2 = KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
zq_ch=zq_ch,
add_conv=add_conv,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
zq_ch=zq_ch,
add_conv=add_conv,
)
)
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = KVAEPXSUpsample(in_channels=block_in)
self.up.insert(0, up)
self.norm_out = KVAEDecoderSpatialNorm2D(block_in, zq_ch, add_conv=add_conv)
self.conv_out = nn.Conv2d(
in_channels=block_in, out_channels=out_ch, kernel_size=3, padding=(1, 1), padding_mode="replicate"
)
self.gradient_checkpointing = False
def forward(self, z: torch.Tensor) -> torch.Tensor:
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
zq = z
h = self.conv_in(z)
# middle
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, zq)
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, zq)
else:
h = self.mid.block_1(h, temb, zq)
h = self.mid.block_2(h, temb, zq)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.up[i_level].block[i_block], h, temb, zq)
else:
h = self.up[i_level].block[i_block](h, temb, zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq)
h = self.nonlinearity(h)
h = self.conv_out(h)
return h
class AutoencoderKLKVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
channels (int, *optional*, defaults to 128): The base number of channels in multiresolution blocks.
num_enc_blocks (int, *optional*, defaults to 2):
The number of Resnet blocks in encoder multiresolution layers.
num_dec_blocks (int, *optional*, defaults to 2):
The number of Resnet blocks in decoder multiresolution layers.
z_channels (int, *optional*, defaults to 16): Number of channels in the latent space.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels of encoder.
ch_mult (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 4, 8)`):
The channel multipliers in multiresolution blocks.
sample_size (`int`, *optional*, defaults to `1024`): Sample input size.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
channels: int = 128,
num_enc_blocks: int = 2,
num_dec_blocks: int = 2,
z_channels: int = 16,
double_z: bool = True,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
sample_size: int = 1024,
):
super().__init__()
# pass init params to Encoder
self.encoder = KVAEEncoder2D(
in_channels=in_channels,
ch=channels,
ch_mult=ch_mult,
num_res_blocks=num_enc_blocks,
z_channels=z_channels,
double_z=double_z,
)
# pass init params to Decoder
self.decoder = KVAEDecoder2D(
out_ch=in_channels,
ch=channels,
ch_mult=ch_mult,
num_res_blocks=num_dec_blocks,
in_channels=None,
z_channels=z_channels,
)
self.use_slicing = False
self.use_tiling = False
# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.ch_mult) - 1)))
self.tile_overlap_factor = 0.25
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
return self._tiled_encode(x)
enc = self.encoder(x)
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.Tensor`): Input batch of images.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split the image into 512x512 tiles and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
enc = torch.cat(result_rows, dim=2)
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
# Split z into overlapping 64x64 tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, z.shape[2], overlap_size):
row = []
for j in range(0, z.shape[3], overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
dec = torch.cat(result_rows, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)

View File

@@ -1,954 +0,0 @@
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def nonlinearity(x: torch.Tensor) -> torch.Tensor:
return F.silu(x)
# =============================================================================
# Base layers
# =============================================================================
class KVAESafeConv3d(nn.Conv3d):
r"""
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM.
"""
def forward(self, input: torch.Tensor, write_to: torch.Tensor = None) -> torch.Tensor:
memory_count = input.numel() * input.element_size() / (10**9)
if memory_count > 3:
kernel_size = self.kernel_size[0]
part_num = math.ceil(memory_count / 2)
input_chunks = torch.chunk(input, part_num, dim=2)
if write_to is None:
output = []
for i, chunk in enumerate(input_chunks):
if i == 0 or kernel_size == 1:
z = torch.clone(chunk)
else:
z = torch.cat([z[:, :, -kernel_size + 1 :], chunk], dim=2)
output.append(super().forward(z))
return torch.cat(output, dim=2)
else:
time_offset = 0
for i, chunk in enumerate(input_chunks):
if i == 0 or kernel_size == 1:
z = torch.clone(chunk)
else:
z = torch.cat([z[:, :, -kernel_size + 1 :], chunk], dim=2)
z_time = z.size(2) - (kernel_size - 1)
write_to[:, :, time_offset : time_offset + z_time] = super().forward(z)
time_offset += z_time
return write_to
else:
if write_to is None:
return super().forward(input)
else:
write_to[...] = super().forward(input)
return write_to
class KVAECausalConv3d(nn.Module):
r"""
A 3D causal convolution layer.
"""
def __init__(
self,
chan_in: int,
chan_out: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Tuple[int, int, int] = (1, 1, 1),
dilation: Tuple[int, int, int] = (1, 1, 1),
**kwargs,
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
self.height_pad = height_kernel_size // 2
self.width_pad = width_kernel_size // 2
self.time_pad = time_kernel_size - 1
self.time_kernel_size = time_kernel_size
self.stride = stride
self.conv = KVAESafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, input: torch.Tensor) -> torch.Tensor:
padding_3d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad, self.time_pad, 0)
input_padded = F.pad(input, padding_3d, mode="replicate")
return self.conv(input_padded)
class KVAECachedCausalConv3d(nn.Module):
r"""
A 3D causal convolution layer with caching for temporal processing.
"""
def __init__(
self,
chan_in: int,
chan_out: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Tuple[int, int, int] = (1, 1, 1),
dilation: Tuple[int, int, int] = (1, 1, 1),
**kwargs,
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
self.height_pad = height_kernel_size // 2
self.width_pad = width_kernel_size // 2
self.time_pad = time_kernel_size - 1
self.time_kernel_size = time_kernel_size
self.stride = stride
self.conv = KVAESafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, input: torch.Tensor, cache: Dict) -> torch.Tensor:
t_stride = self.stride[0]
padding_3d = (self.height_pad, self.height_pad, self.width_pad, self.width_pad, 0, 0)
input_parallel = F.pad(input, padding_3d, mode="replicate")
if cache["padding"] is None:
first_frame = input_parallel[:, :, :1]
time_pad_shape = list(first_frame.shape)
time_pad_shape[2] = self.time_pad
padding = first_frame.expand(time_pad_shape)
else:
padding = cache["padding"]
out_size = list(input.shape)
out_size[1] = self.conv.out_channels
if t_stride == 2:
out_size[2] = (input.size(2) + 1) // 2
output = torch.empty(tuple(out_size), dtype=input.dtype, device=input.device)
offset_out = math.ceil(padding.size(2) / t_stride)
offset_in = offset_out * t_stride - padding.size(2)
if offset_out > 0:
padding_poisoned = torch.cat(
[padding, input_parallel[:, :, : offset_in + self.time_kernel_size - t_stride]], dim=2
)
output[:, :, :offset_out] = self.conv(padding_poisoned)
if offset_out < output.size(2):
output[:, :, offset_out:] = self.conv(input_parallel[:, :, offset_in:])
pad_offset = (
offset_in
+ t_stride * math.trunc((input_parallel.size(2) - offset_in - self.time_kernel_size) / t_stride)
+ t_stride
)
cache["padding"] = torch.clone(input_parallel[:, :, pad_offset:])
return output
class KVAECachedGroupNorm(nn.Module):
r"""
GroupNorm with caching support for temporal processing.
"""
def __init__(self, in_channels: int):
super().__init__()
self.norm_layer = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
def forward(self, x: torch.Tensor, cache: Dict = None) -> torch.Tensor:
out = self.norm_layer(x)
if cache is not None and cache.get("mean") is None and cache.get("var") is None:
cache["mean"] = 1
cache["var"] = 1
return out
# =============================================================================
# Cached layers
# =============================================================================
class KVAECachedSpatialNorm3D(nn.Module):
r"""
Spatially conditioned normalization for decoder with caching.
"""
def __init__(
self,
f_channels: int,
zq_channels: int,
add_conv: bool = False,
):
super().__init__()
self.norm_layer = KVAECachedGroupNorm(f_channels)
self.add_conv = add_conv
if add_conv:
self.conv = KVAECachedCausalConv3d(chan_in=zq_channels, chan_out=zq_channels, kernel_size=3)
self.conv_y = KVAESafeConv3d(zq_channels, f_channels, kernel_size=1)
self.conv_b = KVAESafeConv3d(zq_channels, f_channels, kernel_size=1)
def forward(self, f: torch.Tensor, zq: torch.Tensor, cache: Dict) -> torch.Tensor:
if cache["norm"].get("mean") is None and cache["norm"].get("var") is None:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
zq_first = F.interpolate(zq_first, size=f_first_size, mode="nearest")
if zq.size(2) > 1:
zq_rest_splits = torch.split(zq_rest, 32, dim=1)
interpolated_splits = [
F.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits
]
zq_rest = torch.cat(interpolated_splits, dim=1)
zq = torch.cat([zq_first, zq_rest], dim=2)
else:
zq = zq_first
else:
f_size = f.shape[-3:]
zq_splits = torch.split(zq, 32, dim=1)
interpolated_splits = [F.interpolate(split, size=f_size, mode="nearest") for split in zq_splits]
zq = torch.cat(interpolated_splits, dim=1)
if self.add_conv:
zq = self.conv(zq, cache["add_conv"])
norm_f = self.norm_layer(f, cache["norm"])
norm_f = norm_f * self.conv_y(zq)
norm_f = norm_f + self.conv_b(zq)
return norm_f
class KVAECachedResnetBlock3D(nn.Module):
r"""
A 3D ResNet block with caching.
"""
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 0,
zq_ch: Optional[int] = None,
add_conv: bool = False,
gather_norm: bool = False,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
if zq_ch is None:
self.norm1 = KVAECachedGroupNorm(in_channels)
else:
self.norm1 = KVAECachedSpatialNorm3D(in_channels, zq_ch, add_conv=add_conv)
self.conv1 = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=out_channels, kernel_size=3)
if temb_channels > 0:
self.temb_proj = nn.Linear(temb_channels, out_channels)
if zq_ch is None:
self.norm2 = KVAECachedGroupNorm(out_channels)
else:
self.norm2 = KVAECachedSpatialNorm3D(out_channels, zq_ch, add_conv=add_conv)
self.conv2 = KVAECachedCausalConv3d(chan_in=out_channels, chan_out=out_channels, kernel_size=3)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=out_channels, kernel_size=3)
else:
self.nin_shortcut = KVAESafeConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: torch.Tensor, temb: torch.Tensor, layer_cache: Dict, zq: torch.Tensor = None) -> torch.Tensor:
h = x
if zq is None:
# Encoder path - norm takes cache
h = self.norm1(h, cache=layer_cache["norm1"])
else:
# Decoder path - spatial norm takes zq and cache
h = self.norm1(h, zq, cache=layer_cache["norm1"])
h = F.silu(h)
h = self.conv1(h, cache=layer_cache["conv1"])
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
if zq is None:
h = self.norm2(h, cache=layer_cache["norm2"])
else:
h = self.norm2(h, zq, cache=layer_cache["norm2"])
h = F.silu(h)
h = self.conv2(h, cache=layer_cache["conv2"])
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x, cache=layer_cache["conv_shortcut"])
else:
x = self.nin_shortcut(x)
return x + h
class KVAECachedPXSDownsample(nn.Module):
r"""
A 3D downsampling layer using PixelUnshuffle with caching.
"""
def __init__(self, in_channels: int, compress_time: bool, factor: int = 2):
super().__init__()
self.temporal_compress = compress_time
self.factor = factor
self.unshuffle = nn.PixelUnshuffle(self.factor)
self.s_pool = nn.AvgPool3d((1, 2, 2), (1, 2, 2))
self.spatial_conv = KVAESafeConv3d(
in_channels,
in_channels,
kernel_size=(1, 3, 3),
stride=(1, 2, 2),
padding=(0, 1, 1),
padding_mode="reflect",
)
if self.temporal_compress:
self.temporal_conv = KVAECachedCausalConv3d(
in_channels, in_channels, kernel_size=(3, 1, 1), stride=(2, 1, 1), dilation=(1, 1, 1)
)
self.linear = nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1)
def spatial_downsample(self, input: torch.Tensor) -> torch.Tensor:
b, c, t, h, w = input.shape
pxs_input = input.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
# pxs_input = rearrange(input, 'b c t h w -> (b t) c h w')
pxs_interm = self.unshuffle(pxs_input)
b_it, c_it, h_it, w_it = pxs_interm.shape
pxs_interm_view = pxs_interm.view(b_it, c_it // self.factor**2, self.factor**2, h_it, w_it)
pxs_out = torch.mean(pxs_interm_view, dim=2)
pxs_out = pxs_out.view(b, t, -1, h_it, w_it).permute(0, 2, 1, 3, 4)
# pxs_out = rearrange(pxs_out, '(b t) c h w -> b c t h w', t=input.size(2))
conv_out = self.spatial_conv(input)
return conv_out + pxs_out
def temporal_downsample(self, input: torch.Tensor, cache: list) -> torch.Tensor:
b, c, t, h, w = input.shape
permuted = input.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t)
if cache[0]["padding"] is None:
first, rest = permuted[..., :1], permuted[..., 1:]
if rest.size(-1) > 0:
rest_interp = F.avg_pool1d(rest, kernel_size=2, stride=2)
full_interp = torch.cat([first, rest_interp], dim=-1)
else:
full_interp = first
else:
rest = permuted
if rest.size(-1) > 0:
full_interp = F.avg_pool1d(rest, kernel_size=2, stride=2)
t_new = full_interp.size(-1)
full_interp = full_interp.view(b, h, w, c, t_new).permute(0, 3, 4, 1, 2)
conv_out = self.temporal_conv(input, cache[0])
return conv_out + full_interp
def forward(self, x: torch.Tensor, cache: list) -> torch.Tensor:
out = self.spatial_downsample(x)
if self.temporal_compress:
out = self.temporal_downsample(out, cache=cache)
return self.linear(out)
class KVAECachedPXSUpsample(nn.Module):
r"""
A 3D upsampling layer using PixelShuffle with caching.
"""
def __init__(self, in_channels: int, compress_time: bool, factor: int = 2):
super().__init__()
self.temporal_compress = compress_time
self.factor = factor
self.shuffle = nn.PixelShuffle(self.factor)
self.spatial_conv = KVAESafeConv3d(
in_channels,
in_channels,
kernel_size=(1, 3, 3),
stride=(1, 1, 1),
padding=(0, 1, 1),
padding_mode="reflect",
)
if self.temporal_compress:
self.temporal_conv = KVAECachedCausalConv3d(
in_channels, in_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), dilation=(1, 1, 1)
)
self.linear = KVAESafeConv3d(in_channels, in_channels, kernel_size=1, stride=1)
def spatial_upsample(self, input: torch.Tensor) -> torch.Tensor:
b, c, t, h, w = input.shape
input_view = input.permute(0, 2, 1, 3, 4).reshape(b, t * c, h, w)
input_interp = F.interpolate(input_view, scale_factor=2, mode="nearest")
input_interp = input_interp.view(b, t, c, 2 * h, 2 * w).permute(0, 2, 1, 3, 4)
out = self.spatial_conv(input_interp)
return input_interp + out
def temporal_upsample(self, input: torch.Tensor, cache: Dict) -> torch.Tensor:
time_factor = 1.0 + 1.0 * (input.size(2) > 1)
if isinstance(time_factor, torch.Tensor):
time_factor = time_factor.item()
repeated = input.repeat_interleave(int(time_factor), dim=2)
if cache["padding"] is None:
tail = repeated[..., int(time_factor - 1) :, :, :]
else:
tail = repeated
conv_out = self.temporal_conv(tail, cache)
return conv_out + tail
def forward(self, x: torch.Tensor, cache: Dict) -> torch.Tensor:
if self.temporal_compress:
x = self.temporal_upsample(x, cache)
s_out = self.spatial_upsample(x)
to = torch.empty_like(s_out)
lin_out = self.linear(s_out, write_to=to)
return lin_out
# =============================================================================
# Cached Encoder/Decoder
# =============================================================================
class KVAECachedEncoder3D(nn.Module):
r"""
Cached 3D Encoder for KVAE.
"""
def __init__(
self,
ch: int = 128,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int = 2,
dropout: float = 0.0,
in_channels: int = 3,
z_channels: int = 16,
double_z: bool = True,
temporal_compress_times: int = 4,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_channels
self.temporal_compress_level = int(np.log2(temporal_compress_times))
self.conv_in = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=self.ch, kernel_size=3)
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
block_in = ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
KVAECachedResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
dropout=dropout,
temb_channels=self.temb_ch,
)
)
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
if i_level < self.temporal_compress_level:
down.downsample = KVAECachedPXSDownsample(block_in, compress_time=True)
else:
down.downsample = KVAECachedPXSDownsample(block_in, compress_time=False)
self.down.append(down)
self.mid = nn.Module()
self.mid.block_1 = KVAECachedResnetBlock3D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.block_2 = KVAECachedResnetBlock3D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.norm_out = KVAECachedGroupNorm(block_in)
self.conv_out = KVAECachedCausalConv3d(
chan_in=block_in, chan_out=2 * z_channels if double_z else z_channels, kernel_size=3
)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor, cache_dict: Dict) -> torch.Tensor:
temb = None
h = self.conv_in(x, cache=cache_dict["conv_in"])
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(
self.down[i_level].block[i_block], h, temb, cache_dict[i_level][i_block]
)
else:
h = self.down[i_level].block[i_block](h, temb, layer_cache=cache_dict[i_level][i_block])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h, cache=cache_dict[i_level]["down"])
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, cache_dict["mid_1"])
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, cache_dict["mid_2"])
else:
h = self.mid.block_1(h, temb, layer_cache=cache_dict["mid_1"])
h = self.mid.block_2(h, temb, layer_cache=cache_dict["mid_2"])
h = self.norm_out(h, cache=cache_dict["norm_out"])
h = nonlinearity(h)
h = self.conv_out(h, cache=cache_dict["conv_out"])
return h
class KVAECachedDecoder3D(nn.Module):
r"""
Cached 3D Decoder for KVAE.
"""
def __init__(
self,
ch: int = 128,
out_ch: int = 3,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int = 2,
dropout: float = 0.0,
z_channels: int = 16,
zq_ch: Optional[int] = None,
add_conv: bool = False,
temporal_compress_times: int = 4,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.temporal_compress_level = int(np.log2(temporal_compress_times))
if zq_ch is None:
zq_ch = z_channels
block_in = ch * ch_mult[self.num_resolutions - 1]
self.conv_in = KVAECachedCausalConv3d(chan_in=z_channels, chan_out=block_in, kernel_size=3)
self.mid = nn.Module()
self.mid.block_1 = KVAECachedResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
self.mid.block_2 = KVAECachedResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
KVAECachedResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
)
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = KVAECachedPXSUpsample(block_in, compress_time=False)
else:
up.upsample = KVAECachedPXSUpsample(block_in, compress_time=True)
self.up.insert(0, up)
self.norm_out = KVAECachedSpatialNorm3D(block_in, zq_ch, add_conv=add_conv)
self.conv_out = KVAECachedCausalConv3d(chan_in=block_in, chan_out=out_ch, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, z: torch.Tensor, cache_dict: Dict) -> torch.Tensor:
temb = None
zq = z
h = self.conv_in(z, cache_dict["conv_in"])
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, cache_dict["mid_1"], zq)
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, cache_dict["mid_2"], zq)
else:
h = self.mid.block_1(h, temb, layer_cache=cache_dict["mid_1"], zq=zq)
h = self.mid.block_2(h, temb, layer_cache=cache_dict["mid_2"], zq=zq)
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(
self.up[i_level].block[i_block], h, temb, cache_dict[i_level][i_block], zq
)
else:
h = self.up[i_level].block[i_block](h, temb, layer_cache=cache_dict[i_level][i_block], zq=zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h, cache_dict[i_level]["up"])
h = self.norm_out(h, zq, cache_dict["norm_out"])
h = nonlinearity(h)
h = self.conv_out(h, cache_dict["conv_out"])
return h
# =============================================================================
# Main AutoencoderKL class
# =============================================================================
class AutoencoderKLKVAEVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
[KVAE](https://github.com/kandinskylab/kvae-1).
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
all models (such as downloading or saving).
Parameters:
ch (`int`, *optional*, defaults to 128): Base channel count.
ch_mult (`Tuple[int]`, *optional*, defaults to `(1, 2, 4, 8)`): Channel multipliers per level.
num_res_blocks (`int`, *optional*, defaults to 2): Number of residual blocks per level.
in_channels (`int`, *optional*, defaults to 3): Number of input channels.
out_ch (`int`, *optional*, defaults to 3): Number of output channels.
z_channels (`int`, *optional*, defaults to 16): Number of latent channels.
temporal_compress_times (`int`, *optional*, defaults to 4): Temporal compression factor.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["KVAECachedResnetBlock3D"]
@register_to_config
def __init__(
self,
ch: int = 128,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int = 2,
in_channels: int = 3,
out_ch: int = 3,
z_channels: int = 16,
temporal_compress_times: int = 4,
):
super().__init__()
self.encoder = KVAECachedEncoder3D(
ch=ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
in_channels=in_channels,
z_channels=z_channels,
double_z=True,
temporal_compress_times=temporal_compress_times,
)
self.decoder = KVAECachedDecoder3D(
ch=ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
out_ch=out_ch,
z_channels=z_channels,
temporal_compress_times=temporal_compress_times,
)
self.use_slicing = False
self.use_tiling = False
def _make_encoder_cache(self) -> Dict:
"""Create empty cache for cached encoder."""
def make_dict(name, p=None):
if name == "conv":
return {"padding": None}
layer, module = name.split("_")
if layer == "norm":
if module == "enc":
return {"mean": None, "var": None}
else:
return {"norm": make_dict("norm_enc"), "add_conv": make_dict("conv")}
elif layer == "resblock":
return {
"norm1": make_dict(f"norm_{module}"),
"norm2": make_dict(f"norm_{module}"),
"conv1": make_dict("conv"),
"conv2": make_dict("conv"),
"conv_shortcut": make_dict("conv"),
}
elif layer.isdigit():
out_dict = {"down": [make_dict("conv"), make_dict("conv")], "up": make_dict("conv")}
for i in range(p):
out_dict[i] = make_dict(f"resblock_{module}")
return out_dict
cache = {
"conv_in": make_dict("conv"),
"mid_1": make_dict("resblock_enc"),
"mid_2": make_dict("resblock_enc"),
"norm_out": make_dict("norm_enc"),
"conv_out": make_dict("conv"),
}
# Encoder uses num_res_blocks per level
for i in range(len(self.config.ch_mult)):
cache[i] = make_dict(f"{i}_enc", p=self.config.num_res_blocks)
return cache
def _make_decoder_cache(self) -> Dict:
"""Create empty cache for decoder."""
def make_dict(name, p=None):
if name == "conv":
return {"padding": None}
layer, module = name.split("_")
if layer == "norm":
if module == "enc":
return {"mean": None, "var": None}
else:
return {"norm": make_dict("norm_enc"), "add_conv": make_dict("conv")}
elif layer == "resblock":
return {
"norm1": make_dict(f"norm_{module}"),
"norm2": make_dict(f"norm_{module}"),
"conv1": make_dict("conv"),
"conv2": make_dict("conv"),
"conv_shortcut": make_dict("conv"),
}
elif layer.isdigit():
out_dict = {"down": [make_dict("conv"), make_dict("conv")], "up": make_dict("conv")}
for i in range(p):
out_dict[i] = make_dict(f"resblock_{module}")
return out_dict
cache = {
"conv_in": make_dict("conv"),
"mid_1": make_dict("resblock_dec"),
"mid_2": make_dict("resblock_dec"),
"norm_out": make_dict("norm_dec"),
"conv_out": make_dict("conv"),
}
for i in range(len(self.config.ch_mult)):
cache[i] = make_dict(f"{i}_dec", p=self.config.num_res_blocks + 1)
return cache
def enable_slicing(self) -> None:
r"""Enable sliced VAE decoding."""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""Disable sliced VAE decoding."""
self.use_slicing = False
def _encode(self, x: torch.Tensor, seg_len: int = 16) -> torch.Tensor:
# The cached encoder processes the video in temporal segments
cache = self._make_encoder_cache()
split_list = [seg_len + 1]
n_frames = x.size(2) - (seg_len + 1)
while n_frames > 0:
split_list.append(seg_len)
n_frames -= seg_len
split_list[-1] += n_frames
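# Example: 20 input frames with seg_len=16 -> split_list starts as [17],
# the loop appends 16 (n_frames becomes -13), and the final adjustment
# trims the last chunk to give [17, 3].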
latent = []
for chunk in torch.split(x, split_list, dim=2):
l = self.encoder(chunk, cache)
sample, _ = torch.chunk(l, 2, dim=1)
latent.append(sample)
return torch.cat(latent, dim=2)
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of videos into latents.
Args:
x (`torch.Tensor`): Input batch of videos with shape (B, C, T, H, W).
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
# _encode keeps only the mean; append a zero log-variance so a DiagonalGaussianDistribution can be built
h_double = torch.cat([h, torch.zeros_like(h)], dim=1)
posterior = DiagonalGaussianDistribution(h_double)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, seg_len: int = 16) -> torch.Tensor:
cache = self._make_decoder_cache()
temporal_compress = self.config.temporal_compress_times
split_list = [seg_len + 1]
n_frames = temporal_compress * (z.size(2) - 1) - seg_len
while n_frames > 0:
split_list.append(seg_len)
n_frames -= seg_len
split_list[-1] += n_frames
split_list = [math.ceil(size / temporal_compress) for size in split_list]
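# Example: temporal_compress_times=4, 6 latent frames, seg_len=16 ->
# pixel-frame split [17, 4] maps via ceil division to latent split [5, 1],
# which sums to z.size(2) == 6.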
recs = []
for chunk in torch.split(z, split_list, dim=2):
out = self.decoder(chunk, cache)
recs.append(out)
return torch.cat(recs, dim=2)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of videos.
Args:
z (`torch.Tensor`): Input batch of latent vectors with shape (B, C, T, H, W).
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`: Decoded video.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
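A minimal encode/decode roundtrip sketch for the class above. The tiny config values are assumptions borrowed from the AutoencoderKLKVAEVideoTests removed later in this diff, not canonical defaults:

import torch
from diffusers import AutoencoderKLKVAEVideo

vae = AutoencoderKLKVAEVideo(
    ch=32, ch_mult=(1, 2), num_res_blocks=1,
    in_channels=3, out_ch=3, z_channels=4,
    temporal_compress_times=2,
)
# The frame count must satisfy (T - 1) % temporal_compress_times == 0, so T=3 works here.
video = torch.randn(1, 3, 3, 16, 16)  # (B, C, T, H, W)
posterior = vae.encode(video).latent_dist
z = posterior.mode()  # deterministic latent (the mean; logvar is zero-padded)
reconstruction = vae.decode(z).sample  # back to (B, C, T, H, W)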

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math
from math import prod
from typing import Any
@@ -24,7 +25,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -306,7 +307,7 @@ class QwenEmbedRope(nn.Module):
return vid_freqs, txt_freqs
@lru_cache_unless_export(maxsize=128)
@functools.lru_cache(maxsize=128)
def _compute_video_freqs(
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
) -> torch.Tensor:
@@ -427,7 +428,7 @@ class QwenEmbedLayer3DRope(nn.Module):
return vid_freqs, txt_freqs
@lru_cache_unless_export(maxsize=None)
@functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
@@ -449,7 +450,7 @@ class QwenEmbedLayer3DRope(nn.Module):
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
@lru_cache_unless_export(maxsize=None)
@functools.lru_cache(maxsize=None)
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs

View File

@@ -521,36 +521,6 @@ class AutoencoderKLHunyuanVideo15(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderKLKVAE(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLKVAEVideo(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLLTX2Audio(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -29,7 +29,6 @@ from numpy.linalg import norm
from packaging import version
from .constants import DIFFUSERS_REQUEST_TIMEOUT
from .deprecation_utils import deprecate
from .import_utils import (
BACKENDS_MAPPING,
is_accelerate_available,
@@ -68,11 +67,9 @@ else:
global_rng = random.Random()
logger = get_logger(__name__)
deprecate(
"diffusers.utils.testing_utils",
"1.0.0",
"diffusers.utils.testing_utils is deprecated and will be removed in a future version. "
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. ",
logger.warning(
"diffusers.utils.testing_utils' is deprecated and will be removed in a future version. "
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
)
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version

View File

@@ -19,16 +19,11 @@ from __future__ import annotations
import functools
import os
from typing import Callable, ParamSpec, TypeVar
from . import logging
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
T = TypeVar("T")
P = ParamSpec("P")
if is_torch_available():
import torch
from torch.fft import fftn, fftshift, ifftn, ifftshift
@@ -338,23 +333,5 @@ def disable_full_determinism():
torch.use_deterministic_algorithms(False)
@functools.wraps(functools.lru_cache)
def lru_cache_unless_export(maxsize=128, typed=False):
def outer_wrapper(fn: Callable[P, T]):
cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
if is_torch_version("<", "2.7.0"):
return cached
@functools.wraps(fn)
def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
if torch.compiler.is_exporting():
return fn(*args, **kwargs)
return cached(*args, **kwargs)
return inner_wrapper
return outer_wrapper
if is_torch_available():
torch_device = get_device()

View File

@@ -1,73 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderKLKVAE
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLKVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLKVAE
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_kvae_config(self):
return {
"in_channels": 3,
"channels": 32,
"num_enc_blocks": 1,
"num_dec_blocks": 1,
"z_channels": 4,
"double_z": True,
"ch_mult": (1, 2),
"sample_size": 32,
}
@property
def dummy_input(self):
batch_size = 2
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_kvae_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"KVAEEncoder2D",
"KVAEDecoder2D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -1,118 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderKLKVAEVideo
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLKVAEVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLKVAEVideo
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_kvae_video_config(self):
return {
"ch": 32,
"ch_mult": (1, 2),
"num_res_blocks": 1,
"in_channels": 3,
"out_ch": 3,
"z_channels": 4,
"temporal_compress_times": 2,
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 3 # satisfies (T-1) % temporal_compress_times == 0 with temporal_compress_times=2
num_channels = 3
sizes = (16, 16)
video = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
return {"sample": video}
@property
def input_shape(self):
return (3, 3, 16, 16)
@property
def output_shape(self):
return (3, 3, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_kvae_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"KVAECachedEncoder3D",
"KVAECachedDecoder3D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
@unittest.skip(
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
)
def test_model_parallelism(self):
pass
@unittest.skip(
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
)
def test_sharded_checkpoints_device_map(self):
pass
def _run_nondeterministic(self, fn):
# reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
# temporarily relax the requirement for training tests that do backward passes.
import torch
torch.use_deterministic_algorithms(False)
try:
fn()
finally:
torch.use_deterministic_algorithms(True)
def test_training(self):
self._run_nondeterministic(super().test_training)
def test_ema_training(self):
self._run_nondeterministic(super().test_ema_training)
@unittest.skip(
"Gradient checkpointing recomputes the forward pass, but the model uses a stateful cache_dict "
"that is mutated during the first forward. On recomputation the cache is already populated, "
"causing a different execution path and numerically different gradients. "
"GC still reduces peak memory usage; gradient correctness in the presence of GC is a known limitation."
)
def test_effective_gradient_checkpointing(self):
pass
def test_layerwise_casting_training(self):
self._run_nondeterministic(super().test_layerwise_casting_training)

View File

@@ -465,7 +465,8 @@ class UNetTesterMixin:
def test_forward_with_norm_groups(self):
if not self._accepts_norm_num_groups(self.model_class):
pytest.skip(f"Test not supported for {self.model_class.__name__}")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)
@@ -480,9 +481,9 @@ class UNetTesterMixin:
if isinstance(output, dict):
output = output.to_tuple()[0]
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
class ModelTesterMixin:

View File

@@ -287,8 +287,9 @@ class ModelTesterMixin:
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
)
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
inputs_dict = self.get_dummy_inputs()
image = model(**inputs_dict, return_dict=False)[0]
new_image = new_model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@@ -308,8 +309,9 @@ class ModelTesterMixin:
new_model.to(torch_device)
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
inputs_dict = self.get_dummy_inputs()
image = model(**inputs_dict, return_dict=False)[0]
new_image = new_model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@@ -337,8 +339,9 @@ class ModelTesterMixin:
model.to(torch_device)
model.eval()
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
inputs_dict = self.get_dummy_inputs()
first = model(**inputs_dict, return_dict=False)[0]
second = model(**inputs_dict, return_dict=False)[0]
first_flat = first.flatten()
second_flat = second.flatten()
@@ -395,8 +398,9 @@ class ModelTesterMixin:
model.to(torch_device)
model.eval()
outputs_dict = model(**self.get_dummy_inputs())
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
inputs_dict = self.get_dummy_inputs()
outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
@@ -523,8 +527,10 @@ class ModelTesterMixin:
new_model = new_model.to(torch_device)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
# Re-create inputs only if they contain a generator (which needs to be reset)
if "generator" in inputs_dict:
inputs_dict = self.get_dummy_inputs()
new_output = new_model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
@@ -563,8 +569,10 @@ class ModelTesterMixin:
new_model = new_model.to(torch_device)
torch.manual_seed(0)
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
# Re-create inputs only if they contain a generator (which needs to be reset)
if "generator" in inputs_dict:
inputs_dict = self.get_dummy_inputs()
new_output = new_model(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"
@@ -614,8 +622,10 @@ class ModelTesterMixin:
model_parallel = model_parallel.to(torch_device)
torch.manual_seed(0)
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
# Re-create inputs only if they contain a generator (which needs to be reset)
if "generator" in inputs_dict:
inputs_dict = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict, return_dict=False)[0]
assert_tensors_close(
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"
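The testing refactor these hunks support composes a config class with focused mixins. A minimal sketch of that pattern, assuming the relative import paths used by the test files in this diff; the UNet2DModel config values are illustrative, not canonical:

import torch
from diffusers import UNet2DModel
from ..testing_utils import BaseModelTesterConfig, ModelTesterMixin  # paths as in the refactored tests

class MyUNetTesterConfig(BaseModelTesterConfig):
    @property
    def model_class(self):
        return UNet2DModel

    def get_init_dict(self):
        # Illustrative small config so tests stay fast.
        return {
            "sample_size": 8,
            "in_channels": 3,
            "out_channels": 3,
            "block_out_channels": (8, 16),
            "norm_num_groups": 8,
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
        }

    def get_dummy_inputs(self):
        sample = torch.randn(2, 3, 8, 8)
        timestep = torch.tensor([10, 10])
        return {"sample": sample, "timestep": timestep}

class TestMyUNet(MyUNetTesterConfig, ModelTesterMixin):
    """Core tests only; memory, training and LoRA live in sibling classes."""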

View File

@@ -92,9 +92,6 @@ class TorchCompileTesterMixin:
model.eval()
model.compile_repeated_blocks(fullgraph=True)
if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=recompile_limit),

View File

@@ -15,6 +15,7 @@
import gc
import json
import logging
import os
import re
@@ -23,10 +24,12 @@ import safetensors.torch
import torch
import torch.nn as nn
from diffusers.utils import logging as diffusers_logging
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import check_if_dicts_are_equal
from ...testing_utils import (
CaptureLogger,
assert_tensors_close,
backend_empty_cache,
is_lora,
@@ -477,12 +480,7 @@ class LoraHotSwappingForModelTesterMixin:
with pytest.raises(RuntimeError, match=msg):
model.enable_lora_hotswap(target_rank=32)
def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
# ensure that enable_lora_hotswap is called before loading the first adapter
import logging
from diffusers.utils import logging as diffusers_logging
def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
@@ -490,31 +488,26 @@ class LoraHotSwappingForModelTesterMixin:
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
diffusers_logging.enable_propagation()
try:
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in record.message for record in caplog.records)
finally:
diffusers_logging.disable_propagation()
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
# check possibility to ignore the error/warning
import logging
logger = diffusers_logging.get_logger("diffusers.loaders.peft")
logger.setLevel(logging.WARNING)
with CaptureLogger(logger) as cap_logger:
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
from diffusers.utils import logging as diffusers_logging
assert msg in str(cap_logger.out), f"Expected warning not found. Captured: {cap_logger.out}"
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
diffusers_logging.enable_propagation()
try:
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert len(caplog.records) == 0
finally:
diffusers_logging.disable_propagation()
logger = diffusers_logging.get_logger("diffusers.loaders.peft")
logger.setLevel(logging.WARNING)
with CaptureLogger(logger) as cap_logger:
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert cap_logger.out == "", f"Expected no warnings but found: {cap_logger.out}"
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
@@ -527,29 +520,20 @@ class LoraHotSwappingForModelTesterMixin:
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog):
# check the error and log
import logging
from diffusers.utils import logging as diffusers_logging
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
target_modules0 = ["to_q"]
target_modules1 = ["to_q", "to_k"]
diffusers_logging.enable_propagation()
try:
with pytest.raises(RuntimeError): # peft raises RuntimeError
with caplog.at_level(logging.ERROR):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=8,
rank1=8,
target_modules0=target_modules0,
target_modules1=target_modules1,
)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
finally:
diffusers_logging.disable_propagation()
with pytest.raises(RuntimeError): # peft raises RuntimeError
with caplog.at_level(logging.ERROR):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=8,
rank1=8,
target_modules0=target_modules0,
target_modules1=target_modules1,
)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
@require_torch_version_greater("2.7.1")

View File

@@ -22,7 +22,6 @@ import torch.distributed as dist
import torch.multiprocessing as mp
from diffusers.models._modeling_parallel import ContextParallelConfig
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
from ...testing_utils import (
is_context_parallel,
@@ -161,21 +160,16 @@ def _custom_mesh_worker(
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
def test_context_parallel_inference(self, cp_type):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
if cp_type == "ring_degree":
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
if active_backend == AttentionBackendName.NATIVE:
pytest.skip("Ring attention is not supported with the native attention backend.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
inputs_dict = self.get_dummy_inputs()
# Move all tensors to CPU for multiprocessing
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
@@ -200,11 +194,6 @@ class ContextParallelTesterMixin:
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)
@pytest.mark.xfail(reason="Context parallel may not support batch_size > 1")
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_batch_inputs(self, cp_type):
self.test_context_parallel_inference(cp_type, batch_size=2)
@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
[
@@ -220,11 +209,6 @@ class ContextParallelTesterMixin:
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
if cp_type == "ring_degree":
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
if active_backend == AttentionBackendName.NATIVE:
pytest.skip("Ring attention is not supported with the native attention backend.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}

View File

@@ -41,6 +41,7 @@ from ..testing_utils import (
ModelOptCompileTesterMixin,
ModelOptTesterMixin,
ModelTesterMixin,
PyramidAttentionBroadcastTesterMixin,
QuantoCompileTesterMixin,
QuantoTesterMixin,
SingleFileTesterMixin,
@@ -150,7 +151,8 @@ class FluxTransformerTesterConfig(BaseModelTesterConfig):
"axes_dims_rope": [4, 4, 8],
}
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
height = width = 4
num_latent_channels = 4
num_image_channels = 3
@@ -217,10 +219,6 @@ class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux Transformer."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux Transformer."""
@@ -414,6 +412,10 @@ class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAn
"""BitsAndBytes + compile tests for Flux Transformer."""
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
"""FirstBlockCache tests for Flux Transformer."""

View File

@@ -13,94 +13,48 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import Flux2Transformer2DModel
from diffusers.models.transformers.transformer_flux2 import (
Flux2KVAttnProcessor,
Flux2KVCache,
Flux2KVLayerCache,
Flux2KVParallelSelfAttnProcessor,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers import Flux2Transformer2DModel, attention_backend
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoCompileTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class Flux2TransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return Flux2Transformer2DModel
class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = Flux2Transformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip tests that exercise the default AttnProcessor
uses_custom_attn_processor = True
@property
def output_shape(self) -> tuple[int, int]:
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
return (16, 4)
@property
def input_shape(self) -> tuple[int, int]:
def output_shape(self):
return (16, 4)
@property
def model_split_percents(self) -> list:
# We override the items here because the transformer under consideration is small.
return [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def uses_custom_attn_processor(self) -> bool:
# Skip tests that exercise the default AttnProcessor
return True
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
return {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"timestep_guidance_channels": 256, # Hardcoded in original code
"axes_dims_rope": [4, 4, 4, 4],
}
def get_dummy_inputs(self, height: int = 4, width: int = 4, batch_size: int = 1) -> dict[str, torch.Tensor]:
def prepare_dummy_input(self, height=4, width=4):
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
@@ -128,286 +82,8 @@ class Flux2TransformerTesterConfig(BaseModelTesterConfig):
"guidance": guidance,
}
class TestFlux2Transformer(Flux2TransformerTesterConfig, ModelTesterMixin):
pass
class TestFlux2TransformerMemory(Flux2TransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Flux2 Transformer."""
class TestFlux2TransformerTraining(Flux2TransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux2 Transformer."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Flux2Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestFlux2TransformerAttention(Flux2TransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux2 Transformer."""
class TestFlux2TransformerContextParallel(Flux2TransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for Flux2 Transformer."""
class TestFlux2TransformerLoRA(Flux2TransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for Flux2 Transformer."""
class TestFlux2TransformerLoRAHotSwap(Flux2TransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for Flux2 Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for LoRA hotswap tests."""
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class TestFlux2TransformerCompile(Flux2TransformerTesterConfig, TorchCompileTesterMixin):
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for compilation tests."""
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class TestFlux2TransformerBitsAndBytes(Flux2TransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Flux2 Transformer."""
class TestFlux2TransformerTorchAo(Flux2TransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Flux2 Transformer."""
class TestFlux2TransformerGGUF(Flux2TransformerTesterConfig, GGUFTesterMixin):
"""GGUF quantization tests for Flux2 Transformer."""
@property
def gguf_filename(self):
return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real FLUX2 model dimensions.
Flux2 defaults: in_channels=128, joint_attention_dim=15360
"""
batch_size = 1
height = 64
width = 64
sequence_length = 512
hidden_states = randn_tensor(
(batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
# Flux2 uses 4D image/text IDs (t, h, w, l)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class TestFlux2TransformerTorchAoCompile(Flux2TransformerTesterConfig, TorchAoCompileTesterMixin):
"""TorchAO + compile tests for Flux2 Transformer."""
class TestFlux2TransformerGGUFCompile(Flux2TransformerTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Flux2 Transformer."""
@property
def gguf_filename(self):
return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real FLUX2 model dimensions.
Flux2 defaults: in_channels=128, joint_attention_dim=15360
"""
batch_size = 1
height = 64
width = 64
sequence_length = 512
hidden_states = randn_tensor(
(batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
# Flux2 uses 4D image/text IDs (t, h, w, l)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class Flux2TransformerKVCacheTesterConfig(BaseModelTesterConfig):
num_ref_tokens = 4
@property
def model_class(self):
return Flux2Transformer2DModel
@property
def output_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def input_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def model_split_percents(self) -> list:
return [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def uses_custom_attn_processor(self) -> bool:
return True
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
@@ -415,210 +91,72 @@ class Flux2TransformerKVCacheTesterConfig(BaseModelTesterConfig):
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"timestep_guidance_channels": 256,
"timestep_guidance_channels": 256, # Hardcoded in original code
"axes_dims_rope": [4, 4, 4, 4],
}
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
num_ref_tokens = self.num_ref_tokens
inputs_dict = self.dummy_input
return init_dict, inputs_dict
ref_hidden_states = randn_tensor(
(batch_size, num_ref_tokens, num_latent_channels), generator=self.generator, device=torch_device
)
img_hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
hidden_states = torch.cat([ref_hidden_states, img_hidden_states], dim=1)
# TODO (Daniel, Sayak): We can remove this test.
def test_flux2_consistency(self, seed=0):
torch.manual_seed(seed)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
ref_t_coords = torch.arange(1)
ref_h_coords = torch.arange(num_ref_tokens)
ref_w_coords = torch.arange(1)
ref_l_coords = torch.arange(1)
ref_ids = torch.cartesian_prod(ref_t_coords, ref_h_coords, ref_w_coords, ref_l_coords)
ref_ids = ref_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
image_ids = torch.cat([ref_ids, image_ids], dim=1)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class TestFlux2TransformerKVCache(Flux2TransformerKVCacheTesterConfig):
"""KV cache tests for Flux2 Transformer."""
def test_kv_layer_cache_store_and_get(self):
cache = Flux2KVLayerCache()
k = torch.randn(1, 4, 2, 16)
v = torch.randn(1, 4, 2, 16)
cache.store(k, v)
k_out, v_out = cache.get()
assert torch.equal(k, k_out)
assert torch.equal(v, v_out)
def test_kv_layer_cache_get_before_store_raises(self):
cache = Flux2KVLayerCache()
try:
cache.get()
assert False, "Expected RuntimeError"
except RuntimeError:
pass
def test_kv_layer_cache_clear(self):
cache = Flux2KVLayerCache()
cache.store(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16))
cache.clear()
assert cache.k_ref is None
assert cache.v_ref is None
def test_kv_cache_structure(self):
num_double = 3
num_single = 2
cache = Flux2KVCache(num_double, num_single)
assert len(cache.double_block_caches) == num_double
assert len(cache.single_block_caches) == num_single
assert cache.num_ref_tokens == 0
for i in range(num_double):
assert isinstance(cache.get_double(i), Flux2KVLayerCache)
for i in range(num_single):
assert isinstance(cache.get_single(i), Flux2KVLayerCache)
def test_kv_cache_clear(self):
cache = Flux2KVCache(2, 1)
cache.num_ref_tokens = 4
cache.get_double(0).store(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16))
cache.clear()
assert cache.num_ref_tokens == 0
assert cache.get_double(0).k_ref is None
def _set_kv_attn_processors(self, model):
for block in model.transformer_blocks:
block.attn.set_processor(Flux2KVAttnProcessor())
for block in model.single_transformer_blocks:
block.attn.set_processor(Flux2KVParallelSelfAttnProcessor())
@torch.no_grad()
def test_extract_mode_returns_cache(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
self._set_kv_attn_processors(model)
output = model(
**self.get_dummy_inputs(),
kv_cache_mode="extract",
num_ref_tokens=self.num_ref_tokens,
ref_fixed_timestep=0.0,
)
assert output.kv_cache is not None
assert isinstance(output.kv_cache, Flux2KVCache)
assert output.kv_cache.num_ref_tokens == self.num_ref_tokens
for layer_cache in output.kv_cache.double_block_caches:
assert layer_cache.k_ref is not None
assert layer_cache.v_ref is not None
for layer_cache in output.kv_cache.single_block_caches:
assert layer_cache.k_ref is not None
assert layer_cache.v_ref is not None
@torch.no_grad()
def test_extract_mode_output_shape(self):
model = self.model_class(**self.get_init_dict())
torch.manual_seed(seed)
model = self.model_class(**init_dict)
# state_dict = model.state_dict()
# for key, param in state_dict.items():
# print(f"{key} | {param.shape}")
# torch.save(state_dict, "/raid/daniel_gu/test_flux2_params/diffusers.pt")
model.to(torch_device)
model.eval()
height, width = 4, 4
output = model(
**self.get_dummy_inputs(height=height, width=width),
kv_cache_mode="extract",
num_ref_tokens=self.num_ref_tokens,
ref_fixed_timestep=0.0,
)
with attention_backend("native"):
with torch.no_grad():
output = model(**inputs_dict)
assert output.sample.shape == (1, height * width, 4)
if isinstance(output, dict):
output = output.to_tuple()[0]
@torch.no_grad()
def test_cached_mode_uses_cache(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
self.assertIsNotNone(output)
height, width = 4, 4
extract_output = model(
**self.get_dummy_inputs(height=height, width=width),
kv_cache_mode="extract",
num_ref_tokens=self.num_ref_tokens,
ref_fixed_timestep=0.0,
)
# input & output have to have the same shape
input_tensor = inputs_dict[self.main_input_name]
expected_shape = input_tensor.shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
base_config = Flux2TransformerTesterConfig()
cached_inputs = base_config.get_dummy_inputs(height=height, width=width)
cached_output = model(
**cached_inputs,
kv_cache=extract_output.kv_cache,
kv_cache_mode="cached",
)
# Check against expected slice
# fmt: off
expected_slice = torch.tensor([-0.3662, 0.4844, 0.6334, -0.3497, 0.2162, 0.0188, 0.0521, -0.2061, -0.2041, -0.0342, -0.7107, 0.4797, -0.3280, 0.7059, -0.0849, 0.4416])
# fmt: on
assert cached_output.sample.shape == (1, height * width, 4)
assert cached_output.kv_cache is None
flat_output = output.cpu().flatten()
generated_slice = torch.cat([flat_output[:8], flat_output[-8:]])
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-4))
@torch.no_grad()
def test_extract_return_dict_false(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Flux2Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
output = model(
**self.get_dummy_inputs(),
kv_cache_mode="extract",
num_ref_tokens=self.num_ref_tokens,
ref_fixed_timestep=0.0,
return_dict=False,
)
assert isinstance(output, tuple)
assert len(output) == 2
assert isinstance(output[1], Flux2KVCache)
class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = Flux2Transformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
@torch.no_grad()
def test_no_kv_cache_mode_returns_no_cache(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
def prepare_init_args_and_inputs_for_common(self):
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
base_config = Flux2TransformerTesterConfig()
output = model(**base_config.get_dummy_inputs())
def prepare_dummy_input(self, height, width):
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
assert output.kv_cache is None
class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = Flux2Transformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def prepare_init_args_and_inputs_for_common(self):
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
def prepare_dummy_input(self, height, width):
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)

View File

@@ -14,7 +14,6 @@
import warnings
import pytest
import torch
from diffusers import QwenImageTransformer2DModel
@@ -78,7 +77,8 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
"axes_dims_rope": (8, 4, 4),
}
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = embedding_dim = 16
height = width = 4
sequence_length = 8
@@ -106,10 +106,9 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
@pytest.mark.parametrize("batch_size", [1, 2])
def test_infers_text_seq_len_from_mask(self, batch_size):
def test_infers_text_seq_len_from_mask(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs(batch_size=batch_size)
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
@@ -123,7 +122,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
assert isinstance(per_sample_len, torch.Tensor)
assert int(per_sample_len.max().item()) == 2
assert normalized_mask.dtype == torch.bool
assert normalized_mask.sum().item() == 2 * batch_size
assert normalized_mask.sum().item() == 2
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
inputs["encoder_hidden_states_mask"] = normalized_mask
@@ -140,7 +139,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
)
assert int(per_sample_len2.max().item()) == 8
assert normalized_mask2.sum().item() == 5 * batch_size
assert normalized_mask2.sum().item() == 5
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], None
@@ -150,10 +149,9 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
assert per_sample_len_none is None
assert normalized_mask_none is None
@pytest.mark.parametrize("batch_size", [1, 2])
def test_non_contiguous_attention_mask(self, batch_size):
def test_non_contiguous_attention_mask(self):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs(batch_size=batch_size)
inputs = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
@@ -286,14 +284,6 @@ class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterM
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for QwenImage Transformer."""
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
def test_hotswapping_compiled_model_linear(self):
super().test_hotswapping_compiled_model_linear()
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
def test_hotswapping_compiled_model_both_linear_and_other(self):
super().test_hotswapping_compiled_model_both_linear_and_other()
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
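The QwenImage mask tests above exercise compute_text_seq_len_from_mask; a condensed sketch of the call shape taken from those asserts. The import path is an assumption (it is not shown in this diff):

import torch
# Assumed location; the tests obtain the function from the QwenImage transformer module.
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask

encoder_hidden_states = torch.randn(1, 8, 16)
mask = torch.zeros(1, 8, dtype=torch.bool)
mask[:, :2] = True  # two valid text tokens
rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(encoder_hidden_states, mask)
assert int(per_sample_len.max().item()) == 2
assert normalized_mask.dtype == torch.bool and normalized_mask.sum().item() == 2
assert rope_len >= encoder_hidden_states.shape[1]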

View File

@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
import torch
@@ -26,64 +24,39 @@ from ...testing_utils import (
slow,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
)
class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet1DModel
main_input_name = "sample"
_LAYERWISE_CASTING_XFAIL_REASON = (
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
)
class UNet1DTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet1DModel testing (standard variant)."""
@property
def dummy_input(self):
batch_size = 4
num_features = 14
seq_len = 16
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
time_step = torch.tensor([10] * batch_size).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (4, 14, 16)
def model_class(self):
return UNet1DModel
@property
def output_shape(self):
return (4, 14, 16)
return (14, 16)
@unittest.skip("Test not supported.")
def test_ema_training(self):
pass
@property
def main_input_name(self):
return "sample"
@unittest.skip("Test not supported.")
def test_training(self):
pass
@unittest.skip("Test not supported.")
def test_layerwise_casting_training(self):
pass
def test_determinism(self):
super().test_determinism()
def test_outputs_equivalence(self):
super().test_outputs_equivalence()
def test_from_save_pretrained(self):
super().test_from_save_pretrained()
def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant()
def test_model_from_pretrained(self):
super().test_model_from_pretrained()
def test_output(self):
super().test_output()
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
def get_init_dict(self):
return {
"block_out_channels": (8, 8, 16, 16),
"in_channels": 14,
"out_channels": 14,
@@ -97,18 +70,40 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
"act_fn": "swish",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_features = 14
seq_len = 16
return {
"sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device),
"timestep": torch.tensor([10] * batch_size).to(torch_device),
}
class TestUNet1D(UNet1DTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Not implemented yet for this UNet")
def test_forward_with_norm_groups(self):
pass
class TestUNet1DMemory(UNet1DTesterConfig, MemoryTesterMixin):
@pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
def test_layerwise_casting_memory(self):
super().test_layerwise_casting_memory()
class TestUNet1DHubLoading(UNet1DTesterConfig):
def test_from_pretrained_hub(self):
model, loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
assert model is not None
assert len(loading_info["missing_keys"]) == 0
model.to(torch_device)
image = model(**self.dummy_input)
image = model(**self.get_dummy_inputs())
assert image is not None, "Make sure output is not None"
@@ -131,12 +126,7 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
# fmt: off
expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
# fmt: on
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet
pass
assert torch.allclose(output_slice, expected_output_slice, rtol=1e-3)
@slow
def test_unet_1d_maestro(self):
@@ -157,98 +147,29 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
assert (output_sum - 224.0896).abs() < 0.5
assert (output_max - 0.0607).abs() < 4e-4
@pytest.mark.xfail(
reason=(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
),
)
def test_layerwise_casting_inference(self):
super().test_layerwise_casting_inference()
@pytest.mark.xfail(
reason=(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
),
)
def test_layerwise_casting_memory(self):
pass
# =============================================================================
# UNet1D RL (Value Function) Model Tests
# =============================================================================
class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet1DModel
main_input_name = "sample"
class UNet1DRLTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet1DModel testing (RL value function variant)."""
@property
def dummy_input(self):
batch_size = 4
num_features = 14
seq_len = 16
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
time_step = torch.tensor([10] * batch_size).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (4, 14, 16)
def model_class(self):
return UNet1DModel
@property
def output_shape(self):
return (4, 14, 1)
return (1,)
def test_determinism(self):
super().test_determinism()
@property
def main_input_name(self):
return "sample"
def test_outputs_equivalence(self):
super().test_outputs_equivalence()
def test_from_save_pretrained(self):
super().test_from_save_pretrained()
def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant()
def test_model_from_pretrained(self):
super().test_model_from_pretrained()
def test_output(self):
# UNetRL is a value function with a different output shape
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
@unittest.skip("Test not supported.")
def test_ema_training(self):
pass
@unittest.skip("Test not supported.")
def test_training(self):
pass
@unittest.skip("Test not supported.")
def test_layerwise_casting_training(self):
pass
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
def get_init_dict(self):
return {
"in_channels": 14,
"out_channels": 14,
"down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],
@@ -264,18 +185,54 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"time_embedding_type": "positional",
"act_fn": "mish",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_features = 14
seq_len = 16
return {
"sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device),
"timestep": torch.tensor([10] * batch_size).to(torch_device),
}
class TestUNet1DRL(UNet1DRLTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Not implemented yet for this UNet")
def test_forward_with_norm_groups(self):
pass
@torch.no_grad()
def test_output(self):
# UNetRL is a value function with a different output shape: (batch, 1)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
inputs_dict = self.get_dummy_inputs()
output = model(**inputs_dict, return_dict=False)[0]
assert output is not None
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
assert output.shape == expected_shape, "Output shape does not match the expected (batch_size, 1)"
class TestUNet1DRLMemory(UNet1DRLTesterConfig, MemoryTesterMixin):
@pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
def test_layerwise_casting_memory(self):
super().test_layerwise_casting_memory()
class TestUNet1DRLHubLoading(UNet1DRLTesterConfig):
def test_from_pretrained_hub(self):
value_function, vf_loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
)
self.assertIsNotNone(value_function)
self.assertEqual(len(vf_loading_info["missing_keys"]), 0)
assert value_function is not None
assert len(vf_loading_info["missing_keys"]) == 0
value_function.to(torch_device)
image = value_function(**self.dummy_input)
image = value_function(**self.get_dummy_inputs())
assert image is not None, "Make sure output is not None"
@@ -299,31 +256,4 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
# fmt: off
expected_output_slice = torch.tensor([165.25] * seq_len)
# fmt: on
self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet
pass
@pytest.mark.xfail(
reason=(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
),
)
def test_layerwise_casting_inference(self):
pass
@pytest.mark.xfail(
reason=(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
),
)
def test_layerwise_casting_memory(self):
pass
assert torch.allclose(output, expected_output_slice, rtol=1e-3)
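Taken together, the unet_1d diffs above show the shape of the refactor: a *TesterConfig class owns model construction (get_init_dict / get_dummy_inputs), and each Test* class mixes exactly one concern onto it. A runnable toy reduction, with nn.Conv1d standing in for a diffusers model; the real BaseModelTesterConfig in tests/models/testing_utils carries many more hooks:

import torch
from torch import nn

class BaseModelTesterConfig:
    @property
    def model_class(self):
        raise NotImplementedError

    @property
    def main_input_name(self):
        return "sample"

    def get_init_dict(self):
        raise NotImplementedError

    def get_dummy_inputs(self):
        raise NotImplementedError

class ModelTesterMixin:
    # one representative check; the real mixin bundles many more
    def test_output(self):
        model = self.model_class(**self.get_init_dict()).eval()
        inputs = self.get_dummy_inputs()
        with torch.no_grad():
            out = model(**inputs)
        assert out.shape == inputs[self.main_input_name].shape

class TinyConvTesterConfig(BaseModelTesterConfig):
    @property
    def model_class(self):
        return nn.Conv1d  # placeholder model; its forward keyword is "input"

    @property
    def main_input_name(self):
        return "input"

    def get_init_dict(self):
        return {"in_channels": 14, "out_channels": 14, "kernel_size": 3, "padding": 1}

    def get_dummy_inputs(self):
        return {"input": torch.randn(4, 14, 16)}

class TestTinyConv(TinyConvTesterConfig, ModelTesterMixin):
    pass  # pytest collects test_output from the mixin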

View File

@@ -15,12 +15,11 @@
import gc
import math
import unittest
import pytest
import torch
from diffusers import UNet2DModel
from diffusers.utils import logging
from ...testing_utils import (
backend_empty_cache,
@@ -31,39 +30,40 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TrainingTesterMixin,
)
logger = logging.get_logger(__name__)
enable_full_determinism()
class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
main_input_name = "sample"
# =============================================================================
# Standard UNet2D Model Tests
# =============================================================================
class UNet2DTesterConfig(BaseModelTesterConfig):
"""Base configuration for standard UNet2DModel testing."""
@property
def dummy_input(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (3, 32, 32)
def model_class(self):
return UNet2DModel
@property
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
"block_out_channels": (4, 8),
"norm_num_groups": 2,
"down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
@@ -74,11 +74,22 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"layers_per_block": 2,
"sample_size": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
}
class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
def test_mid_block_attn_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["add_attention"] = True
init_dict["attn_norm_num_groups"] = 4
@@ -87,13 +98,11 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model.to(torch_device)
model.eval()
self.assertIsNotNone(
model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not."
assert model.mid_block.attentions[0].group_norm is not None, (
"Mid block Attention group norm should exist but does not."
)
self.assertEqual(
model.mid_block.attentions[0].group_norm.num_groups,
init_dict["attn_norm_num_groups"],
"Mid block Attention group norm does not have the expected number of groups.",
assert model.mid_block.attentions[0].group_norm.num_groups == init_dict["attn_norm_num_groups"], (
"Mid block Attention group norm does not have the expected number of groups."
)
with torch.no_grad():
@@ -102,13 +111,15 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
if isinstance(output, dict):
output = output.to_tuple()[0]
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_mid_block_none(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
mid_none_init_dict = self.get_init_dict()
mid_none_inputs_dict = self.get_dummy_inputs()
mid_none_init_dict["mid_block_type"] = None
model = self.model_class(**init_dict)
@@ -119,7 +130,7 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
mid_none_model.to(torch_device)
mid_none_model.eval()
self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")
assert mid_none_model.mid_block is None, "Mid block should not exist."
with torch.no_grad():
output = model(**inputs_dict)
@@ -133,8 +144,10 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
if isinstance(mid_none_output, dict):
mid_none_output = mid_none_output.to_tuple()[0]
self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")
assert not torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different."
class TestUNet2DTraining(UNet2DTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"AttnUpBlock2D",
@@ -143,41 +156,32 @@ class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"UpBlock2D",
"DownBlock2D",
}
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
attention_head_dim = 8
block_out_channels = (16, 32)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
)
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
main_input_name = "sample"
# =============================================================================
# UNet2D LDM Model Tests
# =============================================================================
class UNet2DLDMTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet2DModel LDM variant testing."""
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (4, 32, 32)
def model_class(self):
return UNet2DModel
@property
def output_shape(self):
return (4, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
"sample_size": 32,
"in_channels": 4,
"out_channels": 4,
@@ -187,17 +191,34 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"down_block_types": ("DownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "UpBlock2D"),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_channels = 4
sizes = (32, 32)
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
}
class TestUNet2DLDMTraining(UNet2DLDMTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig):
def test_from_pretrained_hub(self):
model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
assert model is not None
assert len(loading_info["missing_keys"]) == 0
model.to(torch_device)
image = model(**self.dummy_input).sample
image = model(**self.get_dummy_inputs()).sample
assert image is not None, "Make sure output is not None"
@@ -205,7 +226,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
def test_from_pretrained_accelerate(self):
model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model.to(torch_device)
image = model(**self.dummy_input).sample
image = model(**self.get_dummy_inputs()).sample
assert image is not None, "Make sure output is not None"
@@ -265,44 +286,31 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
# fmt: on
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
attention_head_dim = 32
block_out_channels = (32, 64)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
)
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-3)
class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
main_input_name = "sample"
# =============================================================================
# NCSN++ Model Tests
# =============================================================================
class NCSNppTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet2DModel NCSN++ variant testing."""
@property
def dummy_input(self, sizes=(32, 32)):
batch_size = 4
num_channels = 3
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (3, 32, 32)
def model_class(self):
return UNet2DModel
@property
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
"block_out_channels": [32, 64, 64, 64],
"in_channels": 3,
"layers_per_block": 1,
@@ -324,17 +332,71 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"SkipUpBlock2D",
],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device),
}
class TestNCSNpp(NCSNppTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Test not supported.")
def test_forward_with_norm_groups(self):
pass
@pytest.mark.skip(
"To make layerwise casting work with this model, we will have to update the implementation. "
"Due to potentially low usage, we don't support it here."
)
def test_keep_in_fp32_modules(self):
pass
@pytest.mark.skip(
"To make layerwise casting work with this model, we will have to update the implementation. "
"Due to potentially low usage, we don't support it here."
)
def test_from_save_pretrained_dtype_inference(self):
pass
class TestNCSNppMemory(NCSNppTesterConfig, MemoryTesterMixin):
@pytest.mark.skip(
"To make layerwise casting work with this model, we will have to update the implementation. "
"Due to potentially low usage, we don't support it here."
)
def test_layerwise_casting_memory(self):
pass
@pytest.mark.skip(
"To make layerwise casting work with this model, we will have to update the implementation. "
"Due to potentially low usage, we don't support it here."
)
def test_layerwise_casting_training(self):
pass
class TestNCSNppTraining(NCSNppTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"UNetMidBlock2D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestNCSNppHubLoading(NCSNppTesterConfig):
@slow
def test_from_pretrained_hub(self):
model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
assert model is not None
assert len(loading_info["missing_keys"]) == 0
model.to(torch_device)
inputs = self.dummy_input
inputs = self.get_dummy_inputs()
noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
inputs["sample"] = noise
image = model(**inputs)
@@ -361,7 +423,7 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056])
# fmt: on
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
def test_output_pretrained_ve_large(self):
model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
@@ -382,35 +444,4 @@ class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
# fmt: on
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# not required for this model
pass
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"UNetMidBlock2D",
}
block_out_channels = (32, 64, 64, 64)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, block_out_channels=block_out_channels
)
def test_effective_gradient_checkpointing(self):
super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})
@unittest.skip(
"To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
)
def test_layerwise_casting_inference(self):
pass
@unittest.skip(
"To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
)
def test_layerwise_casting_memory(self):
pass
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
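Because every concern now lives in its own Test* class (e.g. TestNCSNppMemory, TestUNet2DLDMTraining), one concern can be selected across all variants in a file. A hypothetical invocation through pytest's programmatic entry point; the path assumes the diffusers repo layout:

import pytest

# run only the memory-focused classes in this file
retcode = pytest.main(["tests/models/unets/test_models_unet_2d.py", "-k", "Memory", "-q"])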

View File

@@ -20,6 +20,7 @@ import tempfile
import unittest
from collections import OrderedDict
import pytest
import torch
from huggingface_hub import snapshot_download
from parameterized import parameterized
@@ -52,17 +53,24 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import (
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
IPAdapterTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchCompileTesterMixin,
UNetTesterMixin,
TrainingTesterMixin,
)
if is_peft_available():
from peft import LoraConfig
from peft.tuners.tuners_utils import BaseTunerLayer
from ..testing_utils.lora import check_if_lora_correctly_set
logger = logging.get_logger(__name__)
@@ -82,16 +90,6 @@ def get_unet_lora_config():
return unet_lora_config
def check_if_lora_correctly_set(model) -> bool:
"""
Checks if the LoRA layers are correctly set with peft
"""
for module in model.modules():
if isinstance(module, BaseTunerLayer):
return True
return False
def create_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
ip_cross_attn_state_dict = {}
@@ -354,34 +352,28 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
return custom_diffusion_attn_procs
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
main_input_name = "sample"
# We override the items here because the unet under consideration is small.
model_split_percents = [0.5, 0.34, 0.4]
class UNet2DConditionTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet2DConditionModel testing."""
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (16, 16)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
def model_class(self):
return UNet2DConditionModel
@property
def input_shape(self):
def output_shape(self) -> tuple[int, int, int]:
return (4, 16, 16)
@property
def output_shape(self):
return (4, 16, 16)
def model_split_percents(self) -> list[float]:
return [0.5, 0.34, 0.4]
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
@property
def main_input_name(self) -> str:
return "sample"
def get_init_dict(self) -> dict:
"""Return UNet2D model initialization arguments."""
return {
"block_out_channels": (4, 8),
"norm_num_groups": 4,
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
@@ -393,26 +385,24 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
"layers_per_block": 1,
"sample_size": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
"""Return dummy inputs for UNet2D model."""
batch_size = 4
num_channels = 4
sizes = (16, 16)
model.enable_xformers_memory_efficient_attention()
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
}
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin):
def test_model_with_attention_head_dim_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -427,12 +417,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_model_with_use_linear_projection(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["use_linear_projection"] = True
@@ -446,12 +437,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_model_with_cross_attention_dim_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["cross_attention_dim"] = (8, 8)
@@ -465,12 +457,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_model_with_simple_projection(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
batch_size, _, _, sample_size = inputs_dict["sample"].shape
@@ -489,12 +482,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_model_with_class_embeddings_concat(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
batch_size, _, _, sample_size = inputs_dict["sample"].shape
@@ -514,12 +508,287 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
# since the use-case (somebody passes in a too-short cross-attn mask) is pretty rare,
# maybe it's fine that this only works for the unclip use-case.
@mark.skip(
reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
)
def test_model_xattn_padding(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
model.to(torch_device)
model.eval()
cond = inputs_dict["encoder_hidden_states"]
with torch.no_grad():
full_cond_out = model(**inputs_dict).sample
assert full_cond_out is not None
batch, tokens, _ = cond.shape
keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
assert trunc_mask_out.allclose(keeplast_out), (
"a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
)
def test_pickle(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
sample = model(**inputs_dict).sample
sample_copy = copy.copy(sample)
assert (sample - sample_copy).abs().max() < 1e-4
def test_asymmetrical_unet(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
# Add asymmetry to configs
init_dict["transformer_layers_per_block"] = [[3, 2], 1]
init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
output = model(**inputs_dict).sample
expected_shape = inputs_dict["sample"].shape
# Check if input and output shapes are the same
assert output.shape == expected_shape, "Input and output shapes do not match"
class TestUNet2DConditionHubLoading(UNet2DConditionTesterConfig):
"""Hub checkpoint loading tests for UNet2DConditionModel."""
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
inputs_dict = self.get_dummy_inputs()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
inputs_dict = self.get_dummy_inputs()
loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_local(self):
inputs_dict = self.get_dummy_inputs()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
inputs_dict = self.get_dummy_inputs()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
inputs_dict = self.get_dummy_inputs()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
inputs_dict = self.get_dummy_inputs()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
inputs_dict = self.get_dummy_inputs()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
inputs_dict = self.get_dummy_inputs()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
loaded_model = self.model_class.from_pretrained(
ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
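The tests above cover the main from_pretrained permutations for sharded checkpoints. A compact usage sketch consolidating them; the repo ids are the hf-internal-testing fixtures exercised above, not production checkpoints:

from huggingface_hub import snapshot_download
from diffusers import UNet2DConditionModel

# direct hub load of fp16 variant weights
unet = UNet2DConditionModel.from_pretrained(
    "hf-internal-testing/tiny-sd-unet-sharded-latest-format", variant="fp16"
)

# pre-downloaded local snapshot, dispatched across available devices
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
unet = UNet2DConditionModel.from_pretrained(
    ckpt_path, local_files_only=True, device_map="auto"
)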
class TestUNet2DConditionLoRA(UNet2DConditionTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for UNet2DConditionModel."""
@require_peft_backend
def test_load_attn_procs_raise_warning(self):
"""Test that deprecated load_attn_procs method raises FutureWarning."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
# forward pass without LoRA
with torch.no_grad():
non_lora_sample = model(**inputs_dict).sample
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
# forward pass with LoRA
with torch.no_grad():
lora_sample_1 = model(**inputs_dict).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
model.unload_lora()
with pytest.warns(FutureWarning, match="Using the `load_attn_procs\\(\\)` method has been deprecated"):
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
# important to still run the remaining checks.
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with torch.no_grad():
lora_sample_2 = model(**inputs_dict).sample
assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
"LoRA injected UNet should produce different results."
)
assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
"Loading from a saved checkpoint should produce identical results."
)
@require_peft_backend
def test_save_attn_procs_raise_warning(self):
"""Test that deprecated save_attn_procs method raises FutureWarning."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with tempfile.TemporaryDirectory() as tmpdirname:
with pytest.warns(FutureWarning, match="Using the `save_attn_procs\\(\\)` method has been deprecated"):
model.save_attn_procs(os.path.join(tmpdirname))
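Note that the `match` argument to pytest.warns is a regex, which is why the parentheses in `load_attn_procs\\(\\)` above are escaped. A minimal, self-contained illustration of the mechanism with a hypothetical deprecated function:

import warnings
import pytest

def legacy_api():
    warnings.warn("Using the `legacy_api()` method has been deprecated.", FutureWarning)

def test_legacy_api_warns():
    # fails if no FutureWarning matching the regex is emitted
    with pytest.warns(FutureWarning, match="has been deprecated"):
        legacy_api()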
class TestUNet2DConditionMemory(UNet2DConditionTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for UNet2DConditionModel."""
class TestUNet2DConditionTraining(UNet2DConditionTesterConfig, TrainingTesterMixin):
"""Training tests for UNet2DConditionModel."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"CrossAttnUpBlock2D",
"CrossAttnDownBlock2D",
"UNetMidBlock2DCrossAttn",
"UpBlock2D",
"Transformer2DModel",
"DownBlock2D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterMixin):
"""Attention processor tests for UNet2DConditionModel."""
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
def test_model_attention_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -544,7 +813,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert output is not None
def test_model_sliceable_head_dim(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -562,21 +831,6 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
for module in model.children():
check_sliceable_dim_attr(module)
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"CrossAttnUpBlock2D",
"CrossAttnDownBlock2D",
"UNetMidBlock2DCrossAttn",
"UpBlock2D",
"Transformer2DModel",
"DownBlock2D",
}
attention_head_dim = (8, 16)
block_out_channels = (16, 32)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
)
def test_special_attn_proc(self):
class AttnEasyProc(torch.nn.Module):
def __init__(self, num):
@@ -618,7 +872,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
return hidden_states
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -645,7 +900,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
]
)
def test_model_xattn_mask(self, mask_dtype):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)})
model.to(torch_device)
@@ -675,39 +931,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
"masking the last token from our cond should be equivalent to truncating that token out of the condition"
)
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
# since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric.
# maybe it's fine that this only works for the unclip use-case.
@mark.skip(
reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
)
def test_model_xattn_padding(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
model.to(torch_device)
model.eval()
cond = inputs_dict["encoder_hidden_states"]
with torch.no_grad():
full_cond_out = model(**inputs_dict).sample
assert full_cond_out is not None
batch, tokens, _ = cond.shape
keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
assert trunc_mask_out.allclose(keeplast_out), (
"a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
)
class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
"""Custom Diffusion processor tests for UNet2DConditionModel."""
def test_custom_diffusion_processors(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -733,8 +963,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert (sample1 - sample2).abs().max() < 3e-3
def test_custom_diffusion_save_load(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -754,7 +984,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname, safe_serialization=False)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin")
@@ -773,8 +1003,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_custom_diffusion_xformers_on_off(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -798,41 +1028,28 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert (sample - on_sample).abs().max() < 1e-4
assert (sample - off_sample).abs().max() < 1e-4
def test_pickle(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterMixin):
"""IP Adapter tests for UNet2DConditionModel."""
model = self.model_class(**init_dict)
model.to(torch_device)
@property
def ip_adapter_processor_cls(self):
return (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)
with torch.no_grad():
sample = model(**inputs_dict).sample
def create_ip_adapter_state_dict(self, model):
return create_ip_adapter_state_dict(model)
sample_copy = copy.copy(sample)
assert (sample - sample_copy).abs().max() < 1e-4
def test_asymmetrical_unet(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
# Add asymmetry to configs
init_dict["transformer_layers_per_block"] = [[3, 2], 1]
init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
output = model(**inputs_dict).sample
expected_shape = inputs_dict["sample"].shape
# Check if input and output shapes are the same
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
batch_size = inputs_dict["encoder_hidden_states"].shape[0]
# for ip-adapter, image_embeds has shape [batch_size, num_images, embed_dim]
cross_attention_dim = getattr(model.config, "cross_attention_dim", 8)
image_embeds = floats_tensor((batch_size, 1, cross_attention_dim)).to(torch_device)
inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]}
return inputs_dict
def test_ip_adapter(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -905,7 +1122,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
def test_ip_adapter_plus(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -977,185 +1195,16 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
class TestUNet2DConditionModelCompile(UNet2DConditionTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for UNet2DConditionModel."""
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_local(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
loaded_model = self.model_class.from_pretrained(
ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_peft_backend
def test_load_attn_procs_raise_warning(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
# forward pass without LoRA
with torch.no_grad():
non_lora_sample = model(**inputs_dict).sample
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
# forward pass with LoRA
with torch.no_grad():
lora_sample_1 = model(**inputs_dict).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
model.unload_lora()
with self.assertWarns(FutureWarning) as warning:
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
warning_message = str(warning.warnings[0].message)
assert "Using the `load_attn_procs()` method has been deprecated" in warning_message
# important to still run the remaining checks.
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with torch.no_grad():
lora_sample_2 = model(**inputs_dict).sample
assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
"LoRA injected UNet should produce different results."
)
assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
"Loading from a saved checkpoint should produce identical results."
)
@require_peft_backend
def test_save_attn_procs_raise_warning(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with tempfile.TemporaryDirectory() as tmpdirname:
with self.assertWarns(FutureWarning) as warning:
model.save_attn_procs(tmpdirname)
warning_message = str(warning.warnings[0].message)
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
def test_torch_compile_repeated_blocks(self):
return super().test_torch_compile_repeated_blocks(recompile_limit=2)
class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
def prepare_init_args_and_inputs_for_common(self):
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
def prepare_init_args_and_inputs_for_common(self):
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
class TestUNet2DConditionModelLoRAHotSwap(UNet2DConditionTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for UNet2DConditionModel."""
@slow

View File

@@ -18,47 +18,44 @@ import unittest
import numpy as np
import torch
from diffusers.models import ModelMixin, UNet3DConditionModel
from diffusers.utils import logging
from diffusers import UNet3DConditionModel
from diffusers.utils.import_utils import is_xformers_available
from ...testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ...testing_utils import (
enable_full_determinism,
floats_tensor,
skip_mps,
torch_device,
)
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
ModelTesterMixin,
)
enable_full_determinism()
logger = logging.get_logger(__name__)
@skip_mps
class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet3DConditionModel
main_input_name = "sample"
class UNet3DConditionTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet3DConditionModel testing."""
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
num_frames = 4
sizes = (16, 16)
noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@property
def input_shape(self):
return (4, 4, 16, 16)
def model_class(self):
return UNet3DConditionModel
@property
def output_shape(self):
return (4, 4, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
"block_out_channels": (4, 8),
"norm_num_groups": 4,
"down_block_types": (
@@ -73,27 +70,25 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
"layers_per_block": 1,
"sample_size": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
def get_dummy_inputs(self):
batch_size = 4
num_channels = 4
num_frames = 4
sizes = (16, 16)
model.enable_xformers_memory_efficient_attention()
return {
"sample": floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
}
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
class TestUNet3DCondition(UNet3DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin):
# Overriding to set `norm_num_groups` needs to be different for this model.
def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (32, 64)
init_dict["norm_num_groups"] = 32
@@ -107,39 +102,74 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
assert output.shape == expected_shape, "Input and output shapes do not match"
# Overriding since the UNet3D outputs a different structure.
@torch.no_grad()
def test_determinism(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
model(**self.dummy_input)
inputs_dict = self.get_dummy_inputs()
first = model(**inputs_dict)
if isinstance(first, dict):
first = first.sample
first = model(**inputs_dict)
if isinstance(first, dict):
first = first.sample
second = model(**inputs_dict)
if isinstance(second, dict):
second = second.sample
second = model(**inputs_dict)
if isinstance(second, dict):
second = second.sample
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
assert max_diff <= 1e-5
def test_feed_forward_chunking(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (32, 64)
init_dict["norm_num_groups"] = 32
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)[0]
model.enable_forward_chunking()
with torch.no_grad():
output_2 = model(**inputs_dict)[0]
assert output.shape == output_2.shape, "Shape doesn't match"
assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2
class TestUNet3DConditionAttention(UNet3DConditionTesterConfig, AttentionTesterMixin):
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
def test_model_attention_slicing(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = 8
@@ -162,22 +192,3 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
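# Aside: a hedged sketch of the slicing knob this test drives. diffusers UNets
# expose set_attention_slice, where "auto" halves the attention heads per
# slice, "max" computes one head at a time, and an int fixes the slice size.
def _attention_slicing_usage_sketch(model):
    model.set_attention_slice("auto")  # trade a little speed for lower peak memory
    model.set_attention_slice(2)  # or pick an explicit slice size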
test_models_unet_controlnetxs.py
@@ -13,59 +13,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import pytest
import torch
from torch import nn
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
from diffusers.utils import logging
from ...testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TrainingTesterMixin,
)
logger = logging.get_logger(__name__)
enable_full_determinism()
class UNetControlNetXSTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNetControlNetXSModel testing."""
@property
def model_class(self):
return UNetControlNetXSModel
@property
def output_shape(self):
return (4, 16, 16)
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
"sample_size": 16,
"down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
"up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
@@ -80,11 +63,23 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"ctrl_max_norm_num_groups": 2,
"ctrl_conditioning_embedding_out_channels": (2, 2),
}
def get_dummy_inputs(self):
batch_size = 4
num_channels = 4
sizes = (16, 16)
conditioning_image_size = (3, 32, 32)
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
"controlnet_cond": floats_tensor((batch_size, *conditioning_image_size)).to(torch_device),
"conditioning_scale": 1,
}
def get_dummy_unet(self):
"""For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
"""Build the underlying UNet for tests that construct UNetControlNetXSModel from UNet + Adapter."""
return UNet2DConditionModel(
block_out_channels=(4, 8),
layers_per_block=2,
@@ -99,10 +94,16 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
)
def get_dummy_controlnet_from_unet(self, unet, **kwargs):
"""For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
# size_ratio and conditioning_embedding_out_channels chosen to keep model small
"""Build the ControlNetXS-Adapter from a UNet."""
return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs)
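# Aside: how the two helpers above meet, as a hedged sketch. The constructor
# kwargs are illustrative stand-ins for get_dummy_unet (whose full argument
# list is elided in this diff); fusing a base UNet with the adapter via the
# UNetControlNetXSModel.from_unet classmethod is what test_from_unet below exercises.
def _from_unet_usage_sketch():
    unet = UNet2DConditionModel(
        block_out_channels=(4, 8),
        layers_per_block=2,
        sample_size=16,
        in_channels=4,
        out_channels=4,
        down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=8,
        norm_num_groups=4,
    )
    adapter = ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2))
    return UNetControlNetXSModel.from_unet(unet, adapter)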
class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# UNetControlNetXSModel only supports SD/SDXL with norm_num_groups=32
pass
def test_from_unet(self):
unet = self.get_dummy_unet()
controlnet = self.get_dummy_controlnet_from_unet(unet)
@@ -115,7 +116,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value)
# # check unet
# everything except down,mid,up blocks
modules_from_unet = [
"time_embedding",
"conv_in",
@@ -152,7 +153,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers")
# # check controlnet
# everything except down,mid,up blocks
modules_from_controlnet = {
"controlnet_cond_embedding": "controlnet_cond_embedding",
"conv_in": "ctrl_conv_in",
@@ -193,12 +194,12 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
for p in module.parameters():
assert p.requires_grad
init_dict = self.get_init_dict()
model = UNetControlNetXSModel(**init_dict)
model.freeze_unet_params()
# # check unet
# everything except down,mid,up blocks
modules_from_unet = [
model.base_time_embedding,
model.base_conv_in,
@@ -236,7 +237,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
assert_frozen(u.upsamplers)
# # check controlnet
# everything except down,mid,up blocks
modules_from_controlnet = [
model.controlnet_cond_embedding,
model.ctrl_conv_in,
@@ -267,16 +268,6 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
for u in model.up_blocks:
assert_unfrozen(u.ctrl_to_base)
@is_flaky
def test_forward_no_control(self):
unet = self.get_dummy_unet()
@@ -287,7 +278,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
unet = unet.to(torch_device)
model = model.to(torch_device)
input_ = self.get_dummy_inputs()
control_specific_input = ["controlnet_cond", "conditioning_scale"]
input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input}
@@ -312,7 +303,7 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model = model.to(torch_device)
model_mix_time = model_mix_time.to(torch_device)
input_ = self.get_dummy_inputs()
with torch.no_grad():
output = model(**input_).sample
@@ -320,7 +311,14 @@ class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
assert output.shape == output_mix_time.shape
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups.
pass
class TestUNetControlNetXSTraining(UNetControlNetXSTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"Transformer2DModel",
"UNetMidBlock2DCrossAttn",
"ControlNetXSCrossAttnDownBlock2D",
"ControlNetXSCrossAttnMidBlock2D",
"ControlNetXSCrossAttnUpBlock2D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
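# Aside: a hedged sketch of what the assertion above verifies. After
# enable_gradient_checkpointing(), every module class named in expected_set
# should carry a truthy flag; the `gradient_checkpointing` attribute name is
# assumed from current diffusers internals.
def _checkpointed_module_names(model):
    model.enable_gradient_checkpointing()
    return {m.__class__.__name__ for m in model.modules() if getattr(m, "gradient_checkpointing", False)}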
test_models_unet_spatiotemporal.py
@@ -16,10 +16,10 @@
import copy
import unittest
import pytest
import torch
from diffusers import UNetSpatioTemporalConditionModel
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from ...testing_utils import (
@@ -28,45 +28,34 @@ from ...testing_utils import (
skip_mps,
torch_device,
)
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
ModelTesterMixin,
TrainingTesterMixin,
)
logger = logging.get_logger(__name__)
enable_full_determinism()
@skip_mps
class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNetSpatioTemporalConditionModel testing."""
@property
def model_class(self):
return UNetSpatioTemporalConditionModel
@property
def output_shape(self):
return (4, 32, 32)
@property
def main_input_name(self):
return "sample"
@property
def fps(self):
return 6
@@ -83,8 +72,8 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
def addition_time_embed_dim(self):
return 32
def get_init_dict(self):
return {
"block_out_channels": (32, 64),
"down_block_types": (
"CrossAttnDownBlockSpatioTemporal",
@@ -103,8 +92,23 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"projection_class_embeddings_input_dim": self.addition_time_embed_dim * 3,
"addition_time_embed_dim": self.addition_time_embed_dim,
}
def get_dummy_inputs(self):
batch_size = 2
num_frames = 2
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
return {
"sample": noise,
"timestep": time_step,
"encoder_hidden_states": encoder_hidden_states,
"added_time_ids": self._get_add_time_ids(),
}
def _get_add_time_ids(self, do_classifier_free_guidance=True):
add_time_ids = [self.fps, self.motion_bucket_id, self.noise_aug_strength]
@@ -124,43 +128,15 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
return add_time_ids
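# Aside: a hedged sketch of the SVD-style micro-conditioning assembled above:
# three scalars form one row of added_time_ids, and classifier-free guidance
# doubles it along the batch dimension. The motion_bucket_id and
# noise_aug_strength values below are illustrative assumptions, not test values.
def _add_time_ids_sketch(do_classifier_free_guidance=True):
    add_time_ids = torch.tensor([[6, 127, 0.02]])  # [fps, motion_bucket_id, noise_aug_strength]
    if do_classifier_free_guidance:
        add_time_ids = torch.cat([add_time_ids, add_time_ids])  # (2, 3): conditional + unconditional rows
    return add_time_ids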
@unittest.skip("Number of Norm Groups is not configurable")
class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Number of Norm Groups is not configurable")
def test_forward_with_norm_groups(self):
pass
@unittest.skip("Deprecated functionality")
def test_model_attention_slicing(self):
pass
@unittest.skip("Not supported")
def test_model_with_use_linear_projection(self):
pass
@unittest.skip("Not supported")
def test_model_with_simple_projection(self):
pass
@unittest.skip("Not supported")
def test_model_with_class_embeddings_concat(self):
pass
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
def test_model_with_num_attention_heads_tuple(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["num_attention_heads"] = (8, 16)
model = self.model_class(**init_dict)
@@ -173,12 +149,13 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
if isinstance(output, dict):
output = output.sample
assert output is not None
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_model_with_cross_attention_dim_tuple(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["cross_attention_dim"] = (32, 32)
@@ -192,27 +169,13 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
if isinstance(output, dict):
output = output.sample
assert output is not None
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_pickle(self):
# enable deterministic behavior for gradient checkpointing
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["num_attention_heads"] = (8, 16)
@@ -225,3 +188,33 @@ class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
sample_copy = copy.copy(sample)
assert (sample - sample_copy).abs().max() < 1e-4
class TestUNetSpatioTemporalAttention(UNetSpatioTemporalTesterConfig, AttentionTesterMixin):
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
class TestUNetSpatioTemporalTraining(UNetSpatioTemporalTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"TransformerSpatioTemporalModel",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"UNetMidBlockSpatioTemporal",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
tests/others/test_utils.py
@@ -13,10 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import unittest
import warnings
import pytest
@@ -184,25 +182,6 @@ class DeprecateTester(unittest.TestCase):
assert str(warning.warning) == "This message is better!!!"
assert "diffusers/tests/others/test_utils.py" in warning.filename
def test_deprecate_testing_utils_module(self):
import diffusers.utils.testing_utils
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always")
importlib.reload(diffusers.utils.testing_utils)
deprecation_warnings = [w for w in caught_warnings if issubclass(w.category, FutureWarning)]
assert len(deprecation_warnings) >= 1, "Expected at least one FutureWarning from diffusers.utils.testing_utils"
messages = [str(w.message) for w in deprecation_warnings]
assert any("diffusers.utils.testing_utils" in msg for msg in messages), (
f"Expected a deprecation warning mentioning 'diffusers.utils.testing_utils', got: {messages}"
)
assert any(
"diffusers.utils.testing_utils is deprecated and will be removed in a future version." in msg
for msg in messages
), f"Expected deprecation message substring not found, got: {messages}"
# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
class ExpectationsTester(unittest.TestCase):