mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-25 01:48:21 +08:00
Compare commits
10 Commits
cosmos-fix
...
cuda-devic
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e4f83d1046 | ||
|
|
7bbd96da5d | ||
|
|
62777fa819 | ||
|
|
f1fd515257 | ||
|
|
afdda57f61 | ||
|
|
5fc2bd2c8f | ||
|
|
6350a7690a | ||
|
|
9d4c9dcf21 | ||
|
|
ef309a1bb0 | ||
|
|
b9761ce5a2 |
@@ -446,6 +446,10 @@
|
||||
title: AutoencoderKLHunyuanVideo
|
||||
- local: api/models/autoencoder_kl_hunyuan_video15
|
||||
title: AutoencoderKLHunyuanVideo15
|
||||
- local: api/models/autoencoder_kl_kvae
|
||||
title: AutoencoderKLKVAE
|
||||
- local: api/models/autoencoder_kl_kvae_video
|
||||
title: AutoencoderKLKVAEVideo
|
||||
- local: api/models/autoencoderkl_audio_ltx_2
|
||||
title: AutoencoderKLLTX2Audio
|
||||
- local: api/models/autoencoderkl_ltx_2
|
||||
|
||||
32
docs/source/en/api/models/autoencoder_kl_kvae.md
Normal file
32
docs/source/en/api/models/autoencoder_kl_kvae.md
Normal file
@@ -0,0 +1,32 @@
|
||||
<!-- Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. -->
|
||||
|
||||
# AutoencoderKLKVAE
|
||||
|
||||
The 2D variational autoencoder (VAE) model with KL loss.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderKLKVAE
|
||||
|
||||
vae = AutoencoderKLKVAE.from_pretrained("kandinskylab/KVAE-2D-1.0", subfolder="diffusers", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## AutoencoderKLKVAE
|
||||
|
||||
[[autodoc]] AutoencoderKLKVAE
|
||||
- decode
|
||||
- all
|
||||
33
docs/source/en/api/models/autoencoder_kl_kvae_video.md
Normal file
33
docs/source/en/api/models/autoencoder_kl_kvae_video.md
Normal file
@@ -0,0 +1,33 @@
|
||||
<!-- Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. -->
|
||||
|
||||
# AutoencoderKLKVAEVideo
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderKLKVAEVideo
|
||||
|
||||
vae = AutoencoderKLKVAEVideo.from_pretrained("kandinskylab/KVAE-3D-1.0", subfolder="diffusers", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## AutoencoderKLKVAEVideo
|
||||
|
||||
[[autodoc]] AutoencoderKLKVAEVideo
|
||||
- decode
|
||||
- all
|
||||
|
||||
@@ -193,6 +193,8 @@ else:
|
||||
"AutoencoderKLHunyuanImageRefiner",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLHunyuanVideo15",
|
||||
"AutoencoderKLKVAE",
|
||||
"AutoencoderKLKVAEVideo",
|
||||
"AutoencoderKLLTX2Audio",
|
||||
"AutoencoderKLLTX2Video",
|
||||
"AutoencoderKLLTXVideo",
|
||||
@@ -975,6 +977,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLKVAE,
|
||||
AutoencoderKLKVAEVideo,
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLLTXVideo,
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Type
|
||||
@@ -32,7 +31,7 @@ from ..models._modeling_parallel import (
|
||||
gather_size_by_comm,
|
||||
)
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
|
||||
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
@@ -327,7 +326,7 @@ class PartitionAnythingSharder:
|
||||
return tensor
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=64)
|
||||
@lru_cache_unless_export(maxsize=64)
|
||||
def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
|
||||
gather_shapes = []
|
||||
for i in range(world_size):
|
||||
|
||||
@@ -40,6 +40,8 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
|
||||
_import_structure["autoencoders.autoencoder_kl_kvae"] = ["AutoencoderKLKVAE"]
|
||||
_import_structure["autoencoders.autoencoder_kl_kvae_video"] = ["AutoencoderKLKVAEVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
|
||||
@@ -161,6 +163,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLKVAE,
|
||||
AutoencoderKLKVAEVideo,
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLLTXVideo,
|
||||
|
||||
@@ -49,7 +49,7 @@ from ..utils import (
|
||||
is_xformers_version,
|
||||
)
|
||||
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
|
||||
from ._modeling_parallel import gather_size_by_comm
|
||||
|
||||
|
||||
@@ -587,7 +587,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
@lru_cache_unless_export(maxsize=128)
|
||||
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
|
||||
batch_size: int,
|
||||
seq_len_q: int,
|
||||
|
||||
@@ -9,6 +9,8 @@ from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
|
||||
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
|
||||
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
|
||||
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
|
||||
from .autoencoder_kl_kvae import AutoencoderKLKVAE
|
||||
from .autoencoder_kl_kvae_video import AutoencoderKLKVAEVideo
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
|
||||
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
|
||||
|
||||
@@ -87,7 +87,14 @@ class HunyuanImageRefinerRMS_norm(nn.Module):
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
|
||||
t in str(x.dtype) for t in ("float4_", "float8_")
|
||||
)
|
||||
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
|
||||
x.dtype
|
||||
)
|
||||
|
||||
return normalized * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class HunyuanImageRefinerAttnBlock(nn.Module):
|
||||
|
||||
@@ -87,7 +87,14 @@ class HunyuanVideo15RMS_norm(nn.Module):
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
|
||||
t in str(x.dtype) for t in ("float4_", "float8_")
|
||||
)
|
||||
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
|
||||
x.dtype
|
||||
)
|
||||
|
||||
return normalized * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class HunyuanVideo15AttnBlock(nn.Module):
|
||||
|
||||
802
src/diffusers/models/autoencoders/autoencoder_kl_kvae.py
Normal file
802
src/diffusers/models/autoencoders/autoencoder_kl_kvae.py
Normal file
@@ -0,0 +1,802 @@
|
||||
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
class KVAEResnetBlock2D(nn.Module):
|
||||
r"""
|
||||
A Resnet block with optional guidance.
|
||||
|
||||
Parameters:
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
out_channels (`int`, *optional*, default to `None`):
|
||||
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
|
||||
conv_shortcut (`bool`, *optional*, default to `False`):
|
||||
If `True` and `in_channels` not equal to `out_channels`, add a 3x3 nn.conv2d layer for skip-connection.
|
||||
temb_channels (`int`, *optional*, default to `512`): The number of channels in timestep embedding.
|
||||
zq_ch (`int`, *optional*, default to `None`): Guidance channels for normalization.
|
||||
add_conv (`bool`, *optional*, default to `False`):
|
||||
If `True` add conv2d layer for normalization.
|
||||
normalization (`nn.Module`, *optional*, default to `None`): The normalization layer.
|
||||
act_fn (`str`, *optional*, default to `"swish"`): The activation function to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
conv_shortcut: bool = False,
|
||||
temb_channels: int = 512,
|
||||
zq_ch: Optional[int] = None,
|
||||
add_conv: bool = False,
|
||||
act_fn: str = "swish",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.nonlinearity = get_activation(act_fn)
|
||||
|
||||
if zq_ch is None:
|
||||
self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
|
||||
else:
|
||||
self.norm1 = KVAEDecoderSpatialNorm2D(in_channels, zq_channels=zq_ch, add_conv=add_conv)
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=(1, 1), padding_mode="replicate"
|
||||
)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
if zq_ch is None:
|
||||
self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True)
|
||||
else:
|
||||
self.norm2 = KVAEDecoderSpatialNorm2D(out_channels, zq_channels=zq_ch, add_conv=add_conv)
|
||||
self.conv2 = nn.Conv2d(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
padding=(1, 1),
|
||||
padding_mode="replicate",
|
||||
)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
padding=(1, 1),
|
||||
padding_mode="replicate",
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: torch.Tensor, zq: torch.Tensor = None) -> torch.Tensor:
|
||||
h = x
|
||||
|
||||
if zq is None:
|
||||
h = self.norm1(h)
|
||||
else:
|
||||
h = self.norm1(h, zq)
|
||||
|
||||
h = self.nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
if zq is None:
|
||||
h = self.norm2(h)
|
||||
else:
|
||||
h = self.norm2(h, zq)
|
||||
|
||||
h = self.nonlinearity(h)
|
||||
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class KVAEPXSDownsample(nn.Module):
|
||||
def __init__(self, in_channels: int, factor: int = 2):
|
||||
r"""
|
||||
A Downsampling module.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
factor (`int`, *optional*, default to `2`): The downsampling factor.
|
||||
"""
|
||||
super().__init__()
|
||||
self.factor = factor
|
||||
self.unshuffle = nn.PixelUnshuffle(self.factor)
|
||||
self.spatial_conv = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode="reflect"
|
||||
)
|
||||
self.linear = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# x: (bchw)
|
||||
pxs_interm = self.unshuffle(x)
|
||||
b, c, h, w = pxs_interm.shape
|
||||
pxs_interm_view = pxs_interm.view(b, c // self.factor**2, self.factor**2, h, w)
|
||||
pxs_out = torch.mean(pxs_interm_view, dim=2)
|
||||
|
||||
conv_out = self.spatial_conv(x)
|
||||
|
||||
# adding it all together
|
||||
out = conv_out + pxs_out
|
||||
return self.linear(out)
|
||||
|
||||
|
||||
class KVAEPXSUpsample(nn.Module):
|
||||
def __init__(self, in_channels: int, factor: int = 2):
|
||||
r"""
|
||||
An Upsampling module.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
factor (`int`, *optional*, default to `2`): The upsampling factor.
|
||||
"""
|
||||
super().__init__()
|
||||
self.factor = factor
|
||||
self.shuffle = nn.PixelShuffle(self.factor)
|
||||
self.spatial_conv = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode="reflect"
|
||||
)
|
||||
|
||||
self.linear = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
repeated = x.repeat_interleave(self.factor**2, dim=1)
|
||||
pxs_interm = self.shuffle(repeated)
|
||||
|
||||
image_like_ups = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
conv_out = self.spatial_conv(image_like_ups)
|
||||
|
||||
# adding it all together
|
||||
out = conv_out + pxs_interm
|
||||
return self.linear(out)
|
||||
|
||||
|
||||
class KVAEDecoderSpatialNorm2D(nn.Module):
|
||||
r"""
|
||||
A 2D normalization module for decoder.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
zq_channels (`int`): The number of channels in the guidance.
|
||||
add_conv (`bool`, *optional*, default to `false`):
|
||||
If `True` add conv2d 3x3 layer for guidance in the beginning.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
zq_channels: int,
|
||||
add_conv: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_layer = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
|
||||
|
||||
self.add_conv = add_conv
|
||||
if add_conv:
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=zq_channels,
|
||||
out_channels=zq_channels,
|
||||
kernel_size=3,
|
||||
padding=(1, 1),
|
||||
padding_mode="replicate",
|
||||
)
|
||||
|
||||
self.conv_y = nn.Conv2d(
|
||||
in_channels=zq_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
self.conv_b = nn.Conv2d(
|
||||
in_channels=zq_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
|
||||
f_first = f
|
||||
f_first_size = f_first.shape[2:]
|
||||
zq = F.interpolate(zq, size=f_first_size, mode="nearest")
|
||||
|
||||
if self.add_conv:
|
||||
zq = self.conv(zq)
|
||||
|
||||
norm_f = self.norm_layer(f)
|
||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||
return new_f
|
||||
|
||||
|
||||
class KVAEEncoder2D(nn.Module):
|
||||
r"""
|
||||
A 2D encoder module.
|
||||
|
||||
Args:
|
||||
ch (`int`): The base number of channels in multiresolution blocks.
|
||||
ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`):
|
||||
The channel multipliers in multiresolution blocks.
|
||||
num_res_blocks (`int`): The number of Resnet blocks.
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
z_channels (`int`): The number of output channels.
|
||||
double_z (`bool`, *optional*, defaults to `True`):
|
||||
Whether to double the number of output channels for the last block.
|
||||
act_fn (`str`, *optional*, default to `"swish"`): The activation function to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch: int,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
||||
num_res_blocks: int,
|
||||
in_channels: int,
|
||||
z_channels: int,
|
||||
double_z: bool = True,
|
||||
act_fn: str = "swish",
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = [num_res_blocks] * self.num_resolutions
|
||||
else:
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.nonlinearity = get_activation(act_fn)
|
||||
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=self.ch,
|
||||
kernel_size=3,
|
||||
padding=(1, 1),
|
||||
)
|
||||
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks[i_level]):
|
||||
block.append(
|
||||
KVAEResnetBlock2D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level < self.num_resolutions - 1:
|
||||
down.downsample = KVAEPXSDownsample(in_channels=block_in) # mb: bad out channels
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = KVAEResnetBlock2D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
)
|
||||
|
||||
self.mid.block_2 = KVAEResnetBlock2D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
)
|
||||
|
||||
# end
|
||||
self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)
|
||||
|
||||
self.conv_out = nn.Conv2d(
|
||||
in_channels=block_in,
|
||||
out_channels=2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
padding=(1, 1),
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
h = self.conv_in(x)
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks[i_level]):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(self.down[i_level].block[i_block], h, temb)
|
||||
else:
|
||||
h = self.down[i_level].block[i_block](h, temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.down[i_level].downsample(h)
|
||||
|
||||
# middle
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb)
|
||||
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb)
|
||||
else:
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = self.nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class KVAEDecoder2D(nn.Module):
|
||||
r"""
|
||||
A 2D decoder module.
|
||||
|
||||
Args:
|
||||
ch (`int`): The base number of channels in multiresolution blocks.
|
||||
out_ch (`int`): The number of output channels.
|
||||
ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`):
|
||||
The channel multipliers in multiresolution blocks.
|
||||
num_res_blocks (`int`): The number of Resnet blocks.
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
z_channels (`int`): The number of input channels.
|
||||
give_pre_end (`bool`, *optional*, default to `false`):
|
||||
If `True` exit the forward pass early and return the penultimate feature map.
|
||||
zq_ch (`bool`, *optional*, default to `None`): The number of channels in the guidance.
|
||||
add_conv (`bool`, *optional*, default to `false`): If `True` add conv2d layer for Resnet normalization layer.
|
||||
act_fn (`str`, *optional*, default to `"swish"`): The activation function to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch: int,
|
||||
out_ch: int,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
||||
num_res_blocks: int,
|
||||
in_channels: int,
|
||||
z_channels: int,
|
||||
give_pre_end: bool = False,
|
||||
zq_ch: Optional[int] = None,
|
||||
add_conv: bool = False,
|
||||
act_fn: str = "swish",
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.nonlinearity = get_activation(act_fn)
|
||||
|
||||
if zq_ch is None:
|
||||
zq_ch = z_channels
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels=z_channels, out_channels=block_in, kernel_size=3, padding=(1, 1), padding_mode="replicate"
|
||||
)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = KVAEResnetBlock2D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
)
|
||||
|
||||
self.mid.block_2 = KVAEResnetBlock2D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
KVAEResnetBlock2D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = KVAEPXSUpsample(in_channels=block_in)
|
||||
self.up.insert(0, up)
|
||||
|
||||
self.norm_out = KVAEDecoderSpatialNorm2D(block_in, zq_ch, add_conv=add_conv) # , gather=gather_norm)
|
||||
|
||||
self.conv_out = nn.Conv2d(
|
||||
in_channels=block_in, out_channels=out_ch, kernel_size=3, padding=(1, 1), padding_mode="replicate"
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, z: torch.Tensor) -> torch.Tensor:
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
zq = z
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, zq)
|
||||
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, zq)
|
||||
else:
|
||||
h = self.mid.block_1(h, temb, zq)
|
||||
h = self.mid.block_2(h, temb, zq)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(self.up[i_level].block[i_block], h, temb, zq)
|
||||
else:
|
||||
h = self.up[i_level].block[i_block](h, temb, zq)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, zq)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h, zq)
|
||||
h = self.nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class AutoencoderKLKVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
|
||||
all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
channels (int, *optional*, defaults to 128): The base number of channels in multiresolution blocks.
|
||||
num_enc_blocks (int, *optional*, defaults to 2):
|
||||
The number of Resnet blocks in encoder multiresolution layers.
|
||||
num_dec_blocks (int, *optional*, defaults to 2):
|
||||
The number of Resnet blocks in decoder multiresolution layers.
|
||||
z_channels (int, *optional*, defaults to 16): Number of channels in the latent space.
|
||||
double_z (`bool`, *optional*, defaults to `True`):
|
||||
Whether to double the number of output channels of encoder.
|
||||
ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`):
|
||||
The channel multipliers in multiresolution blocks.
|
||||
sample_size (`int`, *optional*, defaults to `1024`): Sample input size.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
channels: int = 128,
|
||||
num_enc_blocks: int = 2,
|
||||
num_dec_blocks: int = 2,
|
||||
z_channels: int = 16,
|
||||
double_z: bool = True,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
||||
sample_size: int = 1024,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# pass init params to Encoder
|
||||
self.encoder = KVAEEncoder2D(
|
||||
in_channels=in_channels,
|
||||
ch=channels,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_enc_blocks,
|
||||
z_channels=z_channels,
|
||||
double_z=double_z,
|
||||
)
|
||||
|
||||
# pass init params to Decoder
|
||||
self.decoder = KVAEDecoder2D(
|
||||
out_ch=in_channels,
|
||||
ch=channels,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_dec_blocks,
|
||||
in_channels=None,
|
||||
z_channels=z_channels,
|
||||
)
|
||||
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
# only relevant if vae tiling is enabled
|
||||
self.tile_sample_min_size = self.config.sample_size
|
||||
sample_size = (
|
||||
self.config.sample_size[0]
|
||||
if isinstance(self.config.sample_size, (list, tuple))
|
||||
else self.config.sample_size
|
||||
)
|
||||
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.ch_mult) - 1)))
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
|
||||
return self._tiled_encode(x)
|
||||
|
||||
enc = self.encoder(x)
|
||||
|
||||
return enc
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded images. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
dec = self.decoder(z)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
|
||||
return b
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
||||
return b
|
||||
|
||||
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
||||
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
||||
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
||||
output, but they should be much less noticeable.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The latent representation of the encoded videos.
|
||||
"""
|
||||
|
||||
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_latent_min_size - blend_extent
|
||||
|
||||
# Split the image into 512x512 tiles and encode them separately.
|
||||
rows = []
|
||||
for i in range(0, x.shape[2], overlap_size):
|
||||
row = []
|
||||
for j in range(0, x.shape[3], overlap_size):
|
||||
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = self.encoder(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=3))
|
||||
|
||||
enc = torch.cat(result_rows, dim=2)
|
||||
return enc
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_sample_min_size - blend_extent
|
||||
|
||||
# Split z into overlapping 64x64 tiles and decode them separately.
|
||||
# The tiles have an overlap to avoid seams between tiles.
|
||||
rows = []
|
||||
for i in range(0, z.shape[2], overlap_size):
|
||||
row = []
|
||||
for j in range(0, z.shape[3], overlap_size):
|
||||
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
||||
decoded = self.decoder(tile)
|
||||
row.append(decoded)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=3))
|
||||
|
||||
dec = torch.cat(result_rows, dim=2)
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
sample_posterior (`bool`, *optional*, defaults to `False`):
|
||||
Whether to sample from the posterior.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
954
src/diffusers/models/autoencoders/autoencoder_kl_kvae_video.py
Normal file
954
src/diffusers/models/autoencoders/autoencoder_kl_kvae_video.py
Normal file
@@ -0,0 +1,954 @@
|
||||
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def nonlinearity(x: torch.Tensor) -> torch.Tensor:
|
||||
return F.silu(x)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Base layers
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class KVAESafeConv3d(nn.Conv3d):
|
||||
r"""
|
||||
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM.
|
||||
"""
|
||||
|
||||
def forward(self, input: torch.Tensor, write_to: torch.Tensor = None) -> torch.Tensor:
|
||||
memory_count = input.numel() * input.element_size() / (10**9)
|
||||
|
||||
if memory_count > 3:
|
||||
kernel_size = self.kernel_size[0]
|
||||
part_num = math.ceil(memory_count / 2)
|
||||
input_chunks = torch.chunk(input, part_num, dim=2)
|
||||
|
||||
if write_to is None:
|
||||
output = []
|
||||
for i, chunk in enumerate(input_chunks):
|
||||
if i == 0 or kernel_size == 1:
|
||||
z = torch.clone(chunk)
|
||||
else:
|
||||
z = torch.cat([z[:, :, -kernel_size + 1 :], chunk], dim=2)
|
||||
output.append(super().forward(z))
|
||||
return torch.cat(output, dim=2)
|
||||
else:
|
||||
time_offset = 0
|
||||
for i, chunk in enumerate(input_chunks):
|
||||
if i == 0 or kernel_size == 1:
|
||||
z = torch.clone(chunk)
|
||||
else:
|
||||
z = torch.cat([z[:, :, -kernel_size + 1 :], chunk], dim=2)
|
||||
z_time = z.size(2) - (kernel_size - 1)
|
||||
write_to[:, :, time_offset : time_offset + z_time] = super().forward(z)
|
||||
time_offset += z_time
|
||||
return write_to
|
||||
else:
|
||||
if write_to is None:
|
||||
return super().forward(input)
|
||||
else:
|
||||
write_to[...] = super().forward(input)
|
||||
return write_to
|
||||
|
||||
|
||||
class KVAECausalConv3d(nn.Module):
|
||||
r"""
|
||||
A 3D causal convolution layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chan_in: int,
|
||||
chan_out: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Tuple[int, int, int] = (1, 1, 1),
|
||||
dilation: Tuple[int, int, int] = (1, 1, 1),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
|
||||
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
|
||||
|
||||
self.height_pad = height_kernel_size // 2
|
||||
self.width_pad = width_kernel_size // 2
|
||||
self.time_pad = time_kernel_size - 1
|
||||
self.time_kernel_size = time_kernel_size
|
||||
self.stride = stride
|
||||
|
||||
self.conv = KVAESafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
padding_3d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad, self.time_pad, 0)
|
||||
input_padded = F.pad(input, padding_3d, mode="replicate")
|
||||
return self.conv(input_padded)
|
||||
|
||||
|
||||
class KVAECachedCausalConv3d(nn.Module):
|
||||
r"""
|
||||
A 3D causal convolution layer with caching for temporal processing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chan_in: int,
|
||||
chan_out: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Tuple[int, int, int] = (1, 1, 1),
|
||||
dilation: Tuple[int, int, int] = (1, 1, 1),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
|
||||
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
|
||||
|
||||
self.height_pad = height_kernel_size // 2
|
||||
self.width_pad = width_kernel_size // 2
|
||||
self.time_pad = time_kernel_size - 1
|
||||
self.time_kernel_size = time_kernel_size
|
||||
self.stride = stride
|
||||
|
||||
self.conv = KVAESafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, input: torch.Tensor, cache: Dict) -> torch.Tensor:
|
||||
t_stride = self.stride[0]
|
||||
padding_3d = (self.height_pad, self.height_pad, self.width_pad, self.width_pad, 0, 0)
|
||||
input_parallel = F.pad(input, padding_3d, mode="replicate")
|
||||
|
||||
if cache["padding"] is None:
|
||||
first_frame = input_parallel[:, :, :1]
|
||||
time_pad_shape = list(first_frame.shape)
|
||||
time_pad_shape[2] = self.time_pad
|
||||
padding = first_frame.expand(time_pad_shape)
|
||||
else:
|
||||
padding = cache["padding"]
|
||||
|
||||
out_size = list(input.shape)
|
||||
out_size[1] = self.conv.out_channels
|
||||
if t_stride == 2:
|
||||
out_size[2] = (input.size(2) + 1) // 2
|
||||
output = torch.empty(tuple(out_size), dtype=input.dtype, device=input.device)
|
||||
|
||||
offset_out = math.ceil(padding.size(2) / t_stride)
|
||||
offset_in = offset_out * t_stride - padding.size(2)
|
||||
|
||||
if offset_out > 0:
|
||||
padding_poisoned = torch.cat(
|
||||
[padding, input_parallel[:, :, : offset_in + self.time_kernel_size - t_stride]], dim=2
|
||||
)
|
||||
output[:, :, :offset_out] = self.conv(padding_poisoned)
|
||||
|
||||
if offset_out < output.size(2):
|
||||
output[:, :, offset_out:] = self.conv(input_parallel[:, :, offset_in:])
|
||||
|
||||
pad_offset = (
|
||||
offset_in
|
||||
+ t_stride * math.trunc((input_parallel.size(2) - offset_in - self.time_kernel_size) / t_stride)
|
||||
+ t_stride
|
||||
)
|
||||
cache["padding"] = torch.clone(input_parallel[:, :, pad_offset:])
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class KVAECachedGroupNorm(nn.Module):
|
||||
r"""
|
||||
GroupNorm with caching support for temporal processing.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.norm_layer = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: Dict = None) -> torch.Tensor:
|
||||
out = self.norm_layer(x)
|
||||
if cache is not None and cache.get("mean") is None and cache.get("var") is None:
|
||||
cache["mean"] = 1
|
||||
cache["var"] = 1
|
||||
return out
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cached layers
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class KVAECachedSpatialNorm3D(nn.Module):
|
||||
r"""
|
||||
Spatially conditioned normalization for decoder with caching.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
f_channels: int,
|
||||
zq_channels: int,
|
||||
add_conv: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_layer = KVAECachedGroupNorm(f_channels)
|
||||
self.add_conv = add_conv
|
||||
|
||||
if add_conv:
|
||||
self.conv = KVAECachedCausalConv3d(chan_in=zq_channels, chan_out=zq_channels, kernel_size=3)
|
||||
|
||||
self.conv_y = KVAESafeConv3d(zq_channels, f_channels, kernel_size=1)
|
||||
self.conv_b = KVAESafeConv3d(zq_channels, f_channels, kernel_size=1)
|
||||
|
||||
def forward(self, f: torch.Tensor, zq: torch.Tensor, cache: Dict) -> torch.Tensor:
|
||||
if cache["norm"].get("mean") is None and cache["norm"].get("var") is None:
|
||||
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
||||
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
||||
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
|
||||
|
||||
zq_first = F.interpolate(zq_first, size=f_first_size, mode="nearest")
|
||||
|
||||
if zq.size(2) > 1:
|
||||
zq_rest_splits = torch.split(zq_rest, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
F.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits
|
||||
]
|
||||
zq_rest = torch.cat(interpolated_splits, dim=1)
|
||||
zq = torch.cat([zq_first, zq_rest], dim=2)
|
||||
else:
|
||||
zq = zq_first
|
||||
else:
|
||||
f_size = f.shape[-3:]
|
||||
zq_splits = torch.split(zq, 32, dim=1)
|
||||
interpolated_splits = [F.interpolate(split, size=f_size, mode="nearest") for split in zq_splits]
|
||||
zq = torch.cat(interpolated_splits, dim=1)
|
||||
|
||||
if self.add_conv:
|
||||
zq = self.conv(zq, cache["add_conv"])
|
||||
|
||||
norm_f = self.norm_layer(f, cache["norm"])
|
||||
norm_f = norm_f * self.conv_y(zq)
|
||||
norm_f = norm_f + self.conv_b(zq)
|
||||
|
||||
return norm_f
|
||||
|
||||
|
||||
class KVAECachedResnetBlock3D(nn.Module):
|
||||
r"""
|
||||
A 3D ResNet block with caching.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
conv_shortcut: bool = False,
|
||||
dropout: float = 0.0,
|
||||
temb_channels: int = 0,
|
||||
zq_ch: Optional[int] = None,
|
||||
add_conv: bool = False,
|
||||
gather_norm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
if zq_ch is None:
|
||||
self.norm1 = KVAECachedGroupNorm(in_channels)
|
||||
else:
|
||||
self.norm1 = KVAECachedSpatialNorm3D(in_channels, zq_ch, add_conv=add_conv)
|
||||
|
||||
self.conv1 = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=out_channels, kernel_size=3)
|
||||
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = nn.Linear(temb_channels, out_channels)
|
||||
|
||||
if zq_ch is None:
|
||||
self.norm2 = KVAECachedGroupNorm(out_channels)
|
||||
else:
|
||||
self.norm2 = KVAECachedSpatialNorm3D(out_channels, zq_ch, add_conv=add_conv)
|
||||
|
||||
self.conv2 = KVAECachedCausalConv3d(chan_in=out_channels, chan_out=out_channels, kernel_size=3)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=out_channels, kernel_size=3)
|
||||
else:
|
||||
self.nin_shortcut = KVAESafeConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: torch.Tensor, layer_cache: Dict, zq: torch.Tensor = None) -> torch.Tensor:
|
||||
h = x
|
||||
|
||||
if zq is None:
|
||||
# Encoder path - norm takes cache
|
||||
h = self.norm1(h, cache=layer_cache["norm1"])
|
||||
else:
|
||||
# Decoder path - spatial norm takes zq and cache
|
||||
h = self.norm1(h, zq, cache=layer_cache["norm1"])
|
||||
|
||||
h = F.silu(h)
|
||||
h = self.conv1(h, cache=layer_cache["conv1"])
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
if zq is None:
|
||||
h = self.norm2(h, cache=layer_cache["norm2"])
|
||||
else:
|
||||
h = self.norm2(h, zq, cache=layer_cache["norm2"])
|
||||
|
||||
h = F.silu(h)
|
||||
h = self.conv2(h, cache=layer_cache["conv2"])
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x, cache=layer_cache["conv_shortcut"])
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class KVAECachedPXSDownsample(nn.Module):
|
||||
r"""
|
||||
A 3D downsampling layer using PixelUnshuffle with caching.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, compress_time: bool, factor: int = 2):
|
||||
super().__init__()
|
||||
self.temporal_compress = compress_time
|
||||
self.factor = factor
|
||||
self.unshuffle = nn.PixelUnshuffle(self.factor)
|
||||
self.s_pool = nn.AvgPool3d((1, 2, 2), (1, 2, 2))
|
||||
|
||||
self.spatial_conv = KVAESafeConv3d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=(1, 3, 3),
|
||||
stride=(1, 2, 2),
|
||||
padding=(0, 1, 1),
|
||||
padding_mode="reflect",
|
||||
)
|
||||
|
||||
if self.temporal_compress:
|
||||
self.temporal_conv = KVAECachedCausalConv3d(
|
||||
in_channels, in_channels, kernel_size=(3, 1, 1), stride=(2, 1, 1), dilation=(1, 1, 1)
|
||||
)
|
||||
|
||||
self.linear = nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1)
|
||||
|
||||
def spatial_downsample(self, input: torch.Tensor) -> torch.Tensor:
|
||||
b, c, t, h, w = input.shape
|
||||
pxs_input = input.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
# pxs_input = rearrange(input, 'b c t h w -> (b t) c h w')
|
||||
pxs_interm = self.unshuffle(pxs_input)
|
||||
b_it, c_it, h_it, w_it = pxs_interm.shape
|
||||
pxs_interm_view = pxs_interm.view(b_it, c_it // self.factor**2, self.factor**2, h_it, w_it)
|
||||
pxs_out = torch.mean(pxs_interm_view, dim=2)
|
||||
pxs_out = pxs_out.view(b, t, -1, h_it, w_it).permute(0, 2, 1, 3, 4)
|
||||
# pxs_out = rearrange(pxs_out, '(b t) c h w -> b c t h w', t=input.size(2))
|
||||
conv_out = self.spatial_conv(input)
|
||||
return conv_out + pxs_out
|
||||
|
||||
def temporal_downsample(self, input: torch.Tensor, cache: list) -> torch.Tensor:
|
||||
b, c, t, h, w = input.shape
|
||||
|
||||
permuted = input.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t)
|
||||
|
||||
if cache[0]["padding"] is None:
|
||||
first, rest = permuted[..., :1], permuted[..., 1:]
|
||||
if rest.size(-1) > 0:
|
||||
rest_interp = F.avg_pool1d(rest, kernel_size=2, stride=2)
|
||||
full_interp = torch.cat([first, rest_interp], dim=-1)
|
||||
else:
|
||||
full_interp = first
|
||||
else:
|
||||
rest = permuted
|
||||
if rest.size(-1) > 0:
|
||||
full_interp = F.avg_pool1d(rest, kernel_size=2, stride=2)
|
||||
|
||||
t_new = full_interp.size(-1)
|
||||
full_interp = full_interp.view(b, h, w, c, t_new).permute(0, 3, 4, 1, 2)
|
||||
conv_out = self.temporal_conv(input, cache[0])
|
||||
return conv_out + full_interp
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: list) -> torch.Tensor:
|
||||
out = self.spatial_downsample(x)
|
||||
|
||||
if self.temporal_compress:
|
||||
out = self.temporal_downsample(out, cache=cache)
|
||||
|
||||
return self.linear(out)
|
||||
|
||||
|
||||
class KVAECachedPXSUpsample(nn.Module):
|
||||
r"""
|
||||
A 3D upsampling layer using PixelShuffle with caching.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, compress_time: bool, factor: int = 2):
|
||||
super().__init__()
|
||||
self.temporal_compress = compress_time
|
||||
self.factor = factor
|
||||
self.shuffle = nn.PixelShuffle(self.factor)
|
||||
|
||||
self.spatial_conv = KVAESafeConv3d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=(1, 3, 3),
|
||||
stride=(1, 1, 1),
|
||||
padding=(0, 1, 1),
|
||||
padding_mode="reflect",
|
||||
)
|
||||
|
||||
if self.temporal_compress:
|
||||
self.temporal_conv = KVAECachedCausalConv3d(
|
||||
in_channels, in_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), dilation=(1, 1, 1)
|
||||
)
|
||||
|
||||
self.linear = KVAESafeConv3d(in_channels, in_channels, kernel_size=1, stride=1)
|
||||
|
||||
def spatial_upsample(self, input: torch.Tensor) -> torch.Tensor:
|
||||
b, c, t, h, w = input.shape
|
||||
input_view = input.permute(0, 2, 1, 3, 4).reshape(b, t * c, h, w)
|
||||
input_interp = F.interpolate(input_view, scale_factor=2, mode="nearest")
|
||||
input_interp = input_interp.view(b, t, c, 2 * h, 2 * w).permute(0, 2, 1, 3, 4)
|
||||
|
||||
out = self.spatial_conv(input_interp)
|
||||
return input_interp + out
|
||||
|
||||
def temporal_upsample(self, input: torch.Tensor, cache: Dict) -> torch.Tensor:
|
||||
time_factor = 1.0 + 1.0 * (input.size(2) > 1)
|
||||
if isinstance(time_factor, torch.Tensor):
|
||||
time_factor = time_factor.item()
|
||||
|
||||
repeated = input.repeat_interleave(int(time_factor), dim=2)
|
||||
|
||||
if cache["padding"] is None:
|
||||
tail = repeated[..., int(time_factor - 1) :, :, :]
|
||||
else:
|
||||
tail = repeated
|
||||
|
||||
conv_out = self.temporal_conv(tail, cache)
|
||||
return conv_out + tail
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: Dict) -> torch.Tensor:
|
||||
if self.temporal_compress:
|
||||
x = self.temporal_upsample(x, cache)
|
||||
|
||||
s_out = self.spatial_upsample(x)
|
||||
to = torch.empty_like(s_out)
|
||||
lin_out = self.linear(s_out, write_to=to)
|
||||
return lin_out
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cached Encoder/Decoder
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class KVAECachedEncoder3D(nn.Module):
|
||||
r"""
|
||||
Cached 3D Encoder for KVAE.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ch: int = 128,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
||||
num_res_blocks: int = 2,
|
||||
dropout: float = 0.0,
|
||||
in_channels: int = 3,
|
||||
z_channels: int = 16,
|
||||
double_z: bool = True,
|
||||
temporal_compress_times: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.in_channels = in_channels
|
||||
self.temporal_compress_level = int(np.log2(temporal_compress_times))
|
||||
|
||||
self.conv_in = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=self.ch, kernel_size=3)
|
||||
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
block_in = ch
|
||||
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
KVAECachedResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
dropout=dropout,
|
||||
temb_channels=self.temb_ch,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
|
||||
if i_level != self.num_resolutions - 1:
|
||||
if i_level < self.temporal_compress_level:
|
||||
down.downsample = KVAECachedPXSDownsample(block_in, compress_time=True)
|
||||
else:
|
||||
down.downsample = KVAECachedPXSDownsample(block_in, compress_time=False)
|
||||
self.down.append(down)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = KVAECachedResnetBlock3D(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
self.mid.block_2 = KVAECachedResnetBlock3D(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
|
||||
self.norm_out = KVAECachedGroupNorm(block_in)
|
||||
self.conv_out = KVAECachedCausalConv3d(
|
||||
chan_in=block_in, chan_out=2 * z_channels if double_z else z_channels, kernel_size=3
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x: torch.Tensor, cache_dict: Dict) -> torch.Tensor:
|
||||
temb = None
|
||||
|
||||
h = self.conv_in(x, cache=cache_dict["conv_in"])
|
||||
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(
|
||||
self.down[i_level].block[i_block], h, temb, cache_dict[i_level][i_block]
|
||||
)
|
||||
else:
|
||||
h = self.down[i_level].block[i_block](h, temb, layer_cache=cache_dict[i_level][i_block])
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.down[i_level].downsample(h, cache=cache_dict[i_level]["down"])
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, cache_dict["mid_1"])
|
||||
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, cache_dict["mid_2"])
|
||||
else:
|
||||
h = self.mid.block_1(h, temb, layer_cache=cache_dict["mid_1"])
|
||||
h = self.mid.block_2(h, temb, layer_cache=cache_dict["mid_2"])
|
||||
|
||||
h = self.norm_out(h, cache=cache_dict["norm_out"])
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, cache=cache_dict["conv_out"])
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class KVAECachedDecoder3D(nn.Module):
|
||||
r"""
|
||||
Cached 3D Decoder for KVAE.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ch: int = 128,
|
||||
out_ch: int = 3,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
||||
num_res_blocks: int = 2,
|
||||
dropout: float = 0.0,
|
||||
z_channels: int = 16,
|
||||
zq_ch: Optional[int] = None,
|
||||
add_conv: bool = False,
|
||||
temporal_compress_times: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.temporal_compress_level = int(np.log2(temporal_compress_times))
|
||||
|
||||
if zq_ch is None:
|
||||
zq_ch = z_channels
|
||||
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
|
||||
self.conv_in = KVAECachedCausalConv3d(chan_in=z_channels, chan_out=block_in, kernel_size=3)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = KVAECachedResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
)
|
||||
self.mid.block_2 = KVAECachedResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
KVAECachedResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
|
||||
if i_level != 0:
|
||||
if i_level < self.num_resolutions - self.temporal_compress_level:
|
||||
up.upsample = KVAECachedPXSUpsample(block_in, compress_time=False)
|
||||
else:
|
||||
up.upsample = KVAECachedPXSUpsample(block_in, compress_time=True)
|
||||
self.up.insert(0, up)
|
||||
|
||||
self.norm_out = KVAECachedSpatialNorm3D(block_in, zq_ch, add_conv=add_conv)
|
||||
self.conv_out = KVAECachedCausalConv3d(chan_in=block_in, chan_out=out_ch, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, z: torch.Tensor, cache_dict: Dict) -> torch.Tensor:
|
||||
temb = None
|
||||
zq = z
|
||||
|
||||
h = self.conv_in(z, cache_dict["conv_in"])
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, cache_dict["mid_1"], zq)
|
||||
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, cache_dict["mid_2"], zq)
|
||||
else:
|
||||
h = self.mid.block_1(h, temb, layer_cache=cache_dict["mid_1"], zq=zq)
|
||||
h = self.mid.block_2(h, temb, layer_cache=cache_dict["mid_2"], zq=zq)
|
||||
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(
|
||||
self.up[i_level].block[i_block], h, temb, cache_dict[i_level][i_block], zq
|
||||
)
|
||||
else:
|
||||
h = self.up[i_level].block[i_block](h, temb, layer_cache=cache_dict[i_level][i_block], zq=zq)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, zq)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h, cache_dict[i_level]["up"])
|
||||
|
||||
h = self.norm_out(h, zq, cache_dict["norm_out"])
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, cache_dict["conv_out"])
|
||||
|
||||
return h
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main AutoencoderKL class
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AutoencoderKLKVAEVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
|
||||
[KVAE](https://github.com/kandinskylab/kvae-1).
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
|
||||
all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
ch (`int`, *optional*, defaults to 128): Base channel count.
|
||||
ch_mult (`Tuple[int]`, *optional*, defaults to `(1, 2, 4, 8)`): Channel multipliers per level.
|
||||
num_res_blocks (`int`, *optional*, defaults to 2): Number of residual blocks per level.
|
||||
in_channels (`int`, *optional*, defaults to 3): Number of input channels.
|
||||
out_ch (`int`, *optional*, defaults to 3): Number of output channels.
|
||||
z_channels (`int`, *optional*, defaults to 16): Number of latent channels.
|
||||
temporal_compress_times (`int`, *optional*, defaults to 4): Temporal compression factor.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["KVAECachedResnetBlock3D"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
ch: int = 128,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
||||
num_res_blocks: int = 2,
|
||||
in_channels: int = 3,
|
||||
out_ch: int = 3,
|
||||
z_channels: int = 16,
|
||||
temporal_compress_times: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder = KVAECachedEncoder3D(
|
||||
ch=ch,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
in_channels=in_channels,
|
||||
z_channels=z_channels,
|
||||
double_z=True,
|
||||
temporal_compress_times=temporal_compress_times,
|
||||
)
|
||||
|
||||
self.decoder = KVAECachedDecoder3D(
|
||||
ch=ch,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
out_ch=out_ch,
|
||||
z_channels=z_channels,
|
||||
temporal_compress_times=temporal_compress_times,
|
||||
)
|
||||
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
def _make_encoder_cache(self) -> Dict:
|
||||
"""Create empty cache for cached encoder."""
|
||||
|
||||
def make_dict(name, p=None):
|
||||
if name == "conv":
|
||||
return {"padding": None}
|
||||
|
||||
layer, module = name.split("_")
|
||||
if layer == "norm":
|
||||
if module == "enc":
|
||||
return {"mean": None, "var": None}
|
||||
else:
|
||||
return {"norm": make_dict("norm_enc"), "add_conv": make_dict("conv")}
|
||||
elif layer == "resblock":
|
||||
return {
|
||||
"norm1": make_dict(f"norm_{module}"),
|
||||
"norm2": make_dict(f"norm_{module}"),
|
||||
"conv1": make_dict("conv"),
|
||||
"conv2": make_dict("conv"),
|
||||
"conv_shortcut": make_dict("conv"),
|
||||
}
|
||||
elif layer.isdigit():
|
||||
out_dict = {"down": [make_dict("conv"), make_dict("conv")], "up": make_dict("conv")}
|
||||
for i in range(p):
|
||||
out_dict[i] = make_dict(f"resblock_{module}")
|
||||
return out_dict
|
||||
|
||||
cache = {
|
||||
"conv_in": make_dict("conv"),
|
||||
"mid_1": make_dict("resblock_enc"),
|
||||
"mid_2": make_dict("resblock_enc"),
|
||||
"norm_out": make_dict("norm_enc"),
|
||||
"conv_out": make_dict("conv"),
|
||||
}
|
||||
# Encoder uses num_res_blocks per level
|
||||
for i in range(len(self.config.ch_mult)):
|
||||
cache[i] = make_dict(f"{i}_enc", p=self.config.num_res_blocks)
|
||||
return cache
|
||||
|
||||
def _make_decoder_cache(self) -> Dict:
|
||||
"""Create empty cache for decoder."""
|
||||
|
||||
def make_dict(name, p=None):
|
||||
if name == "conv":
|
||||
return {"padding": None}
|
||||
|
||||
layer, module = name.split("_")
|
||||
if layer == "norm":
|
||||
if module == "enc":
|
||||
return {"mean": None, "var": None}
|
||||
else:
|
||||
return {"norm": make_dict("norm_enc"), "add_conv": make_dict("conv")}
|
||||
elif layer == "resblock":
|
||||
return {
|
||||
"norm1": make_dict(f"norm_{module}"),
|
||||
"norm2": make_dict(f"norm_{module}"),
|
||||
"conv1": make_dict("conv"),
|
||||
"conv2": make_dict("conv"),
|
||||
"conv_shortcut": make_dict("conv"),
|
||||
}
|
||||
elif layer.isdigit():
|
||||
out_dict = {"down": [make_dict("conv"), make_dict("conv")], "up": make_dict("conv")}
|
||||
for i in range(p):
|
||||
out_dict[i] = make_dict(f"resblock_{module}")
|
||||
return out_dict
|
||||
|
||||
cache = {
|
||||
"conv_in": make_dict("conv"),
|
||||
"mid_1": make_dict("resblock_dec"),
|
||||
"mid_2": make_dict("resblock_dec"),
|
||||
"norm_out": make_dict("norm_dec"),
|
||||
"conv_out": make_dict("conv"),
|
||||
}
|
||||
for i in range(len(self.config.ch_mult)):
|
||||
cache[i] = make_dict(f"{i}_dec", p=self.config.num_res_blocks + 1)
|
||||
return cache
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""Enable sliced VAE decoding."""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""Disable sliced VAE decoding."""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor, seg_len: int = 16) -> torch.Tensor:
|
||||
# Cached encoder processes by segments
|
||||
cache = self._make_encoder_cache()
|
||||
|
||||
split_list = [seg_len + 1]
|
||||
n_frames = x.size(2) - (seg_len + 1)
|
||||
while n_frames > 0:
|
||||
split_list.append(seg_len)
|
||||
n_frames -= seg_len
|
||||
split_list[-1] += n_frames
|
||||
|
||||
latent = []
|
||||
for chunk in torch.split(x, split_list, dim=2):
|
||||
l = self.encoder(chunk, cache)
|
||||
sample, _ = torch.chunk(l, 2, dim=1)
|
||||
latent.append(sample)
|
||||
|
||||
return torch.cat(latent, dim=2)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
"""
|
||||
Encode a batch of videos into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of videos with shape (B, C, T, H, W).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded videos.
|
||||
"""
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
|
||||
# For cached encoder, we already did the split in _encode
|
||||
h_double = torch.cat([h, torch.zeros_like(h)], dim=1)
|
||||
posterior = DiagonalGaussianDistribution(h_double)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, seg_len: int = 16) -> torch.Tensor:
|
||||
cache = self._make_decoder_cache()
|
||||
temporal_compress = self.config.temporal_compress_times
|
||||
|
||||
split_list = [seg_len + 1]
|
||||
n_frames = temporal_compress * (z.size(2) - 1) - seg_len
|
||||
while n_frames > 0:
|
||||
split_list.append(seg_len)
|
||||
n_frames -= seg_len
|
||||
split_list[-1] += n_frames
|
||||
split_list = [math.ceil(size / temporal_compress) for size in split_list]
|
||||
|
||||
recs = []
|
||||
for chunk in torch.split(z, split_list, dim=2):
|
||||
out = self.decoder(chunk, cache)
|
||||
recs.append(out)
|
||||
|
||||
return torch.cat(recs, dim=2)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
"""
|
||||
Decode a batch of videos.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors with shape (B, C, T, H, W).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`: Decoded video.
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z).sample
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
return DecoderOutput(sample=dec)
|
||||
@@ -105,7 +105,14 @@ class QwenImageRMS_norm(nn.Module):
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
|
||||
t in str(x.dtype) for t in ("float4_", "float8_")
|
||||
)
|
||||
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
|
||||
x.dtype
|
||||
)
|
||||
|
||||
return normalized * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class QwenImageUpsample(nn.Upsample):
|
||||
|
||||
@@ -196,7 +196,14 @@ class WanRMS_norm(nn.Module):
|
||||
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
||||
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
|
||||
t in str(x.dtype) for t in ("float4_", "float8_")
|
||||
)
|
||||
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
|
||||
x.dtype
|
||||
)
|
||||
|
||||
return normalized * self.scale * self.gamma + self.bias
|
||||
|
||||
|
||||
class WanUpsample(nn.Upsample):
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import math
|
||||
from math import prod
|
||||
from typing import Any
|
||||
@@ -25,7 +24,7 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, deprecate, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -307,7 +306,7 @@ class QwenEmbedRope(nn.Module):
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
@lru_cache_unless_export(maxsize=128)
|
||||
def _compute_video_freqs(
|
||||
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
|
||||
) -> torch.Tensor:
|
||||
@@ -428,7 +427,7 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@lru_cache_unless_export(maxsize=None)
|
||||
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
|
||||
seq_lens = frame * height * width
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
@@ -450,7 +449,7 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@lru_cache_unless_export(maxsize=None)
|
||||
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
|
||||
seq_lens = frame * height * width
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
@@ -934,6 +933,7 @@ class QwenImageTransformer2DModel(
|
||||
batch_size, image_seq_len = hidden_states.shape[:2]
|
||||
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
|
||||
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
|
||||
joint_attention_mask = joint_attention_mask[:, None, None, :]
|
||||
block_attention_kwargs["attention_mask"] = joint_attention_mask
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
|
||||
@@ -16,22 +16,29 @@ from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms
|
||||
import torchvision.transforms.functional
|
||||
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
|
||||
from ...schedulers import UniPCMultistepScheduler
|
||||
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils import (
|
||||
is_cosmos_guardrail_available,
|
||||
is_torch_xla_available,
|
||||
is_torchvision_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import CosmosPipelineOutput
|
||||
|
||||
|
||||
if is_torchvision_available():
|
||||
import torchvision.transforms.functional
|
||||
|
||||
|
||||
if is_cosmos_guardrail_available():
|
||||
from cosmos_guardrail import CosmosSafetyChecker
|
||||
else:
|
||||
|
||||
@@ -521,6 +521,36 @@ class AutoencoderKLHunyuanVideo15(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLKVAE(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLKVAEVideo(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLLTX2Audio(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from numpy.linalg import norm
|
||||
from packaging import version
|
||||
|
||||
from .constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from .deprecation_utils import deprecate
|
||||
from .import_utils import (
|
||||
BACKENDS_MAPPING,
|
||||
is_accelerate_available,
|
||||
@@ -67,9 +68,11 @@ else:
|
||||
global_rng = random.Random()
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.warning(
|
||||
"diffusers.utils.testing_utils' is deprecated and will be removed in a future version. "
|
||||
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
|
||||
deprecate(
|
||||
"diffusers.utils.testing_utils",
|
||||
"1.0.0",
|
||||
"diffusers.utils.testing_utils is deprecated and will be removed in a future version. "
|
||||
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. ",
|
||||
)
|
||||
_required_peft_version = is_peft_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("peft")).base_version
|
||||
|
||||
@@ -19,11 +19,16 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Callable, ParamSpec, TypeVar
|
||||
|
||||
from . import logging
|
||||
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.fft import fftn, fftshift, ifftn, ifftshift
|
||||
@@ -333,5 +338,23 @@ def disable_full_determinism():
|
||||
torch.use_deterministic_algorithms(False)
|
||||
|
||||
|
||||
@functools.wraps(functools.lru_cache)
|
||||
def lru_cache_unless_export(maxsize=128, typed=False):
|
||||
def outer_wrapper(fn: Callable[P, T]):
|
||||
cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
|
||||
if is_torch_version("<", "2.7.0"):
|
||||
return cached
|
||||
|
||||
@functools.wraps(fn)
|
||||
def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
if torch.compiler.is_exporting():
|
||||
return fn(*args, **kwargs)
|
||||
return cached(*args, **kwargs)
|
||||
|
||||
return inner_wrapper
|
||||
|
||||
return outer_wrapper
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
torch_device = get_device()
|
||||
|
||||
73
tests/models/autoencoders/test_models_autoencoder_kl_kvae.py
Normal file
73
tests/models/autoencoders/test_models_autoencoder_kl_kvae.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLKVAE
|
||||
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLKVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLKVAE
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_kvae_config(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"channels": 32,
|
||||
"num_enc_blocks": 1,
|
||||
"num_dec_blocks": 1,
|
||||
"z_channels": 4,
|
||||
"double_z": True,
|
||||
"ch_mult": (1, 2),
|
||||
"sample_size": 32,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_kvae_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"KVAEEncoder2D",
|
||||
"KVAEDecoder2D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
@@ -0,0 +1,118 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLKVAEVideo
|
||||
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLKVAEVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLKVAEVideo
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_kvae_video_config(self):
|
||||
return {
|
||||
"ch": 32,
|
||||
"ch_mult": (1, 2),
|
||||
"num_res_blocks": 1,
|
||||
"in_channels": 3,
|
||||
"out_ch": 3,
|
||||
"z_channels": 4,
|
||||
"temporal_compress_times": 2,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 3 # satisfies (T-1) % temporal_compress_times == 0 with temporal_compress_times=2
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
|
||||
video = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": video}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 3, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 3, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_kvae_video_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"KVAECachedEncoder3D",
|
||||
"KVAECachedDecoder3D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
|
||||
)
|
||||
def test_model_parallelism(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
|
||||
)
|
||||
def test_sharded_checkpoints_device_map(self):
|
||||
pass
|
||||
|
||||
def _run_nondeterministic(self, fn):
|
||||
# reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
|
||||
# temporarily relax the requirement for training tests that do backward passes.
|
||||
import torch
|
||||
|
||||
torch.use_deterministic_algorithms(False)
|
||||
try:
|
||||
fn()
|
||||
finally:
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
def test_training(self):
|
||||
self._run_nondeterministic(super().test_training)
|
||||
|
||||
def test_ema_training(self):
|
||||
self._run_nondeterministic(super().test_ema_training)
|
||||
|
||||
@unittest.skip(
|
||||
"Gradient checkpointing recomputes the forward pass, but the model uses a stateful cache_dict "
|
||||
"that is mutated during the first forward. On recomputation the cache is already populated, "
|
||||
"causing a different execution path and numerically different gradients. "
|
||||
"GC still reduces peak memory usage; gradient correctness in the presence of GC is a known limitation."
|
||||
)
|
||||
def test_effective_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
def test_layerwise_casting_training(self):
|
||||
self._run_nondeterministic(super().test_layerwise_casting_training)
|
||||
@@ -481,6 +481,8 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||
import logging
|
||||
|
||||
from diffusers.utils import logging as diffusers_logging
|
||||
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
@@ -488,21 +490,31 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
msg = (
|
||||
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
|
||||
)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
assert any(msg in record.message for record in caplog.records)
|
||||
diffusers_logging.enable_propagation()
|
||||
try:
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
assert any(msg in record.message for record in caplog.records)
|
||||
finally:
|
||||
diffusers_logging.disable_propagation()
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
|
||||
# check possibility to ignore the error/warning
|
||||
import logging
|
||||
|
||||
from diffusers.utils import logging as diffusers_logging
|
||||
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
|
||||
assert len(caplog.records) == 0
|
||||
diffusers_logging.enable_propagation()
|
||||
try:
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
|
||||
assert len(caplog.records) == 0
|
||||
finally:
|
||||
diffusers_logging.disable_propagation()
|
||||
|
||||
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
|
||||
# check that wrong argument value raises an error
|
||||
@@ -518,20 +530,26 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
# check the error and log
|
||||
import logging
|
||||
|
||||
from diffusers.utils import logging as diffusers_logging
|
||||
|
||||
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
|
||||
target_modules0 = ["to_q"]
|
||||
target_modules1 = ["to_q", "to_k"]
|
||||
with pytest.raises(RuntimeError): # peft raises RuntimeError
|
||||
with caplog.at_level(logging.ERROR):
|
||||
self._check_model_hotswap(
|
||||
tmp_path,
|
||||
do_compile=True,
|
||||
rank0=8,
|
||||
rank1=8,
|
||||
target_modules0=target_modules0,
|
||||
target_modules1=target_modules1,
|
||||
)
|
||||
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
|
||||
diffusers_logging.enable_propagation()
|
||||
try:
|
||||
with pytest.raises(RuntimeError): # peft raises RuntimeError
|
||||
with caplog.at_level(logging.ERROR):
|
||||
self._check_model_hotswap(
|
||||
tmp_path,
|
||||
do_compile=True,
|
||||
rank0=8,
|
||||
rank1=8,
|
||||
target_modules0=target_modules0,
|
||||
target_modules1=target_modules1,
|
||||
)
|
||||
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
|
||||
finally:
|
||||
diffusers_logging.disable_propagation()
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
@require_torch_version_greater("2.7.1")
|
||||
|
||||
@@ -22,6 +22,7 @@ import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from diffusers.models._modeling_parallel import ContextParallelConfig
|
||||
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
|
||||
|
||||
from ...testing_utils import (
|
||||
is_context_parallel,
|
||||
@@ -160,16 +161,21 @@ def _custom_mesh_worker(
|
||||
@require_torch_multi_accelerator
|
||||
class ContextParallelTesterMixin:
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_inference(self, cp_type):
|
||||
def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
|
||||
if not torch.distributed.is_available():
|
||||
pytest.skip("torch.distributed is not available.")
|
||||
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
if cp_type == "ring_degree":
|
||||
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
|
||||
if active_backend == AttentionBackendName.NATIVE:
|
||||
pytest.skip("Ring attention is not supported with the native attention backend.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
|
||||
|
||||
# Move all tensors to CPU for multiprocessing
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
@@ -194,6 +200,10 @@ class ContextParallelTesterMixin:
|
||||
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_batch_inputs(self, cp_type):
|
||||
self.test_context_parallel_inference(cp_type, batch_size=2)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cp_type,mesh_shape,mesh_dim_names",
|
||||
[
|
||||
@@ -209,6 +219,11 @@ class ContextParallelTesterMixin:
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
if cp_type == "ring_degree":
|
||||
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
|
||||
if active_backend == AttentionBackendName.NATIVE:
|
||||
pytest.skip("Ring attention is not supported with the native attention backend.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
|
||||
|
||||
@@ -150,8 +150,7 @@ class FluxTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"axes_dims_rope": [4, 4, 8],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
height = width = 4
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
|
||||
@@ -90,8 +90,7 @@ class Flux2TransformerTesterConfig(BaseModelTesterConfig):
|
||||
"axes_dims_rope": [4, 4, 4, 4],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
@@ -77,8 +78,7 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_latent_channels = embedding_dim = 16
|
||||
height = width = 4
|
||||
sequence_length = 8
|
||||
@@ -106,9 +106,10 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
|
||||
|
||||
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_infers_text_seq_len_from_mask(self):
|
||||
@pytest.mark.parametrize("batch_size", [1, 2])
|
||||
def test_infers_text_seq_len_from_mask(self, batch_size):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.get_dummy_inputs(batch_size=batch_size)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
|
||||
@@ -122,7 +123,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
assert isinstance(per_sample_len, torch.Tensor)
|
||||
assert int(per_sample_len.max().item()) == 2
|
||||
assert normalized_mask.dtype == torch.bool
|
||||
assert normalized_mask.sum().item() == 2
|
||||
assert normalized_mask.sum().item() == 2 * batch_size
|
||||
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
|
||||
|
||||
inputs["encoder_hidden_states_mask"] = normalized_mask
|
||||
@@ -139,7 +140,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
)
|
||||
|
||||
assert int(per_sample_len2.max().item()) == 8
|
||||
assert normalized_mask2.sum().item() == 5
|
||||
assert normalized_mask2.sum().item() == 5 * batch_size
|
||||
|
||||
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], None
|
||||
@@ -149,9 +150,10 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
assert per_sample_len_none is None
|
||||
assert normalized_mask_none is None
|
||||
|
||||
def test_non_contiguous_attention_mask(self):
|
||||
@pytest.mark.parametrize("batch_size", [1, 2])
|
||||
def test_non_contiguous_attention_mask(self, batch_size):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.get_dummy_inputs(batch_size=batch_size)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
|
||||
@@ -284,6 +286,14 @@ class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterM
|
||||
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for QwenImage Transformer."""
|
||||
|
||||
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
|
||||
def test_hotswapping_compiled_model_linear(self):
|
||||
super().test_hotswapping_compiled_model_linear()
|
||||
|
||||
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
|
||||
def test_hotswapping_compiled_model_both_linear_and_other(self):
|
||||
super().test_hotswapping_compiled_model_both_linear_and_other()
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
@@ -13,8 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -182,6 +184,25 @@ class DeprecateTester(unittest.TestCase):
|
||||
assert str(warning.warning) == "This message is better!!!"
|
||||
assert "diffusers/tests/others/test_utils.py" in warning.filename
|
||||
|
||||
def test_deprecate_testing_utils_module(self):
|
||||
import diffusers.utils.testing_utils
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
warnings.simplefilter("always")
|
||||
importlib.reload(diffusers.utils.testing_utils)
|
||||
|
||||
deprecation_warnings = [w for w in caught_warnings if issubclass(w.category, FutureWarning)]
|
||||
assert len(deprecation_warnings) >= 1, "Expected at least one FutureWarning from diffusers.utils.testing_utils"
|
||||
|
||||
messages = [str(w.message) for w in deprecation_warnings]
|
||||
assert any("diffusers.utils.testing_utils" in msg for msg in messages), (
|
||||
f"Expected a deprecation warning mentioning 'diffusers.utils.testing_utils', got: {messages}"
|
||||
)
|
||||
assert any(
|
||||
"diffusers.utils.testing_utils is deprecated and will be removed in a future version." in msg
|
||||
for msg in messages
|
||||
), f"Expected deprecation message substring not found, got: {messages}"
|
||||
|
||||
|
||||
# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
|
||||
class ExpectationsTester(unittest.TestCase):
|
||||
|
||||
@@ -1534,14 +1534,18 @@ class PipelineTesterMixin:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.to("cpu")
|
||||
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
|
||||
model_devices = [
|
||||
component.device.type for component in components.values() if getattr(component, "device", None)
|
||||
]
|
||||
self.assertTrue(all(device == "cpu" for device in model_devices))
|
||||
|
||||
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
|
||||
self.assertTrue(np.isnan(output_cpu).sum() == 0)
|
||||
|
||||
pipe.to(torch_device)
|
||||
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
|
||||
model_devices = [
|
||||
component.device.type for component in components.values() if getattr(component, "device", None)
|
||||
]
|
||||
self.assertTrue(all(device == torch_device for device in model_devices))
|
||||
|
||||
output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
|
||||
@@ -1552,11 +1556,11 @@ class PipelineTesterMixin:
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
|
||||
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
|
||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
||||
|
||||
pipe.to(dtype=torch.float16)
|
||||
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
|
||||
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
|
||||
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
||||
|
||||
def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):
|
||||
|
||||
@@ -43,7 +43,7 @@ def filter_pipelines(usage_dict, usage_cutoff=10000):
|
||||
|
||||
|
||||
def fetch_pipeline_objects():
|
||||
models = api.list_models(library="diffusers")
|
||||
models = api.list_models(filter="diffusers")
|
||||
downloads = defaultdict(int)
|
||||
|
||||
for model in models:
|
||||
|
||||
Reference in New Issue
Block a user