mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-24 17:38:15 +08:00
Compare commits
11 Commits
tests-cond
...
fix-lora-t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
01acf7216c | ||
|
|
4f9330678c | ||
|
|
6350a7690a | ||
|
|
9d4c9dcf21 | ||
|
|
ef309a1bb0 | ||
|
|
b9761ce5a2 | ||
|
|
52558b45d8 | ||
|
|
c02c17c6ee | ||
|
|
a9855c4204 | ||
|
|
0b35834351 | ||
|
|
522b523e40 |
@@ -446,6 +446,10 @@
|
||||
title: AutoencoderKLHunyuanVideo
|
||||
- local: api/models/autoencoder_kl_hunyuan_video15
|
||||
title: AutoencoderKLHunyuanVideo15
|
||||
- local: api/models/autoencoder_kl_kvae
|
||||
title: AutoencoderKLKVAE
|
||||
- local: api/models/autoencoder_kl_kvae_video
|
||||
title: AutoencoderKLKVAEVideo
|
||||
- local: api/models/autoencoderkl_audio_ltx_2
|
||||
title: AutoencoderKLLTX2Audio
|
||||
- local: api/models/autoencoderkl_ltx_2
|
||||
|
||||
32
docs/source/en/api/models/autoencoder_kl_kvae.md
Normal file
32
docs/source/en/api/models/autoencoder_kl_kvae.md
Normal file
@@ -0,0 +1,32 @@
|
||||
<!-- Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. -->
|
||||
|
||||
# AutoencoderKLKVAE
|
||||
|
||||
The 2D variational autoencoder (VAE) model with KL loss.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderKLKVAE
|
||||
|
||||
vae = AutoencoderKLKVAE.from_pretrained("kandinskylab/KVAE-2D-1.0", subfolder="diffusers", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## AutoencoderKLKVAE
|
||||
|
||||
[[autodoc]] AutoencoderKLKVAE
|
||||
- decode
|
||||
- all
|
||||
33
docs/source/en/api/models/autoencoder_kl_kvae_video.md
Normal file
33
docs/source/en/api/models/autoencoder_kl_kvae_video.md
Normal file
@@ -0,0 +1,33 @@
|
||||
<!-- Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. -->
|
||||
|
||||
# AutoencoderKLKVAEVideo
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderKLKVAEVideo
|
||||
|
||||
vae = AutoencoderKLKVAEVideo.from_pretrained("kandinskylab/KVAE-3D-1.0", subfolder="diffusers", torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
## AutoencoderKLKVAEVideo
|
||||
|
||||
[[autodoc]] AutoencoderKLKVAEVideo
|
||||
- decode
|
||||
- all
|
||||
|
||||
@@ -143,6 +143,7 @@ Refer to the table below for a complete list of available attention backends and
|
||||
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
|
||||
| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
|
||||
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
|
||||
| `flash_4_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-4 |
|
||||
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
|
||||
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
|
||||
|
||||
@@ -193,6 +193,8 @@ else:
|
||||
"AutoencoderKLHunyuanImageRefiner",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLHunyuanVideo15",
|
||||
"AutoencoderKLKVAE",
|
||||
"AutoencoderKLKVAEVideo",
|
||||
"AutoencoderKLLTX2Audio",
|
||||
"AutoencoderKLLTX2Video",
|
||||
"AutoencoderKLLTXVideo",
|
||||
@@ -975,6 +977,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLKVAE,
|
||||
AutoencoderKLKVAEVideo,
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLLTXVideo,
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Type
|
||||
@@ -32,7 +31,7 @@ from ..models._modeling_parallel import (
|
||||
gather_size_by_comm,
|
||||
)
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
|
||||
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
@@ -327,7 +326,7 @@ class PartitionAnythingSharder:
|
||||
return tensor
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=64)
|
||||
@lru_cache_unless_export(maxsize=64)
|
||||
def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
|
||||
gather_shapes = []
|
||||
for i in range(world_size):
|
||||
|
||||
@@ -40,6 +40,8 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
|
||||
_import_structure["autoencoders.autoencoder_kl_kvae"] = ["AutoencoderKLKVAE"]
|
||||
_import_structure["autoencoders.autoencoder_kl_kvae_video"] = ["AutoencoderKLKVAEVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
|
||||
@@ -161,6 +163,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLKVAE,
|
||||
AutoencoderKLKVAEVideo,
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLLTXVideo,
|
||||
|
||||
@@ -49,7 +49,7 @@ from ..utils import (
|
||||
is_xformers_version,
|
||||
)
|
||||
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
|
||||
from ._modeling_parallel import gather_size_by_comm
|
||||
|
||||
|
||||
@@ -229,6 +229,7 @@ class AttentionBackendName(str, Enum):
|
||||
FLASH_HUB = "flash_hub"
|
||||
FLASH_VARLEN = "flash_varlen"
|
||||
FLASH_VARLEN_HUB = "flash_varlen_hub"
|
||||
FLASH_4_HUB = "flash_4_hub"
|
||||
_FLASH_3 = "_flash_3"
|
||||
_FLASH_VARLEN_3 = "_flash_varlen_3"
|
||||
_FLASH_3_HUB = "_flash_3_hub"
|
||||
@@ -358,6 +359,11 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
function_attr="sageattn",
|
||||
version=1,
|
||||
),
|
||||
AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-staging/flash-attn4",
|
||||
function_attr="flash_attn_func",
|
||||
version=0,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -521,6 +527,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
AttentionBackendName._FLASH_3_VARLEN_HUB,
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
AttentionBackendName.FLASH_4_HUB,
|
||||
]:
|
||||
if not is_kernels_available():
|
||||
raise RuntimeError(
|
||||
@@ -531,6 +538,11 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
|
||||
)
|
||||
|
||||
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
|
||||
raise RuntimeError(
|
||||
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
|
||||
)
|
||||
|
||||
elif backend == AttentionBackendName.AITER:
|
||||
if not _CAN_USE_AITER_ATTN:
|
||||
raise RuntimeError(
|
||||
@@ -575,7 +587,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
@lru_cache_unless_export(maxsize=128)
|
||||
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
|
||||
batch_size: int,
|
||||
seq_len_q: int,
|
||||
@@ -2676,6 +2688,37 @@ def _flash_attention_3_varlen_hub(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_4_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_4_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
scale: float | None = None,
|
||||
is_causal: bool = False,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 4.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
)
|
||||
if isinstance(out, tuple):
|
||||
return (out[0], out[1]) if return_lse else out[0]
|
||||
return out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_VARLEN_3,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
|
||||
@@ -9,6 +9,8 @@ from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
|
||||
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
|
||||
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
|
||||
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
|
||||
from .autoencoder_kl_kvae import AutoencoderKLKVAE
|
||||
from .autoencoder_kl_kvae_video import AutoencoderKLKVAEVideo
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
|
||||
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
|
||||
|
||||
802
src/diffusers/models/autoencoders/autoencoder_kl_kvae.py
Normal file
802
src/diffusers/models/autoencoders/autoencoder_kl_kvae.py
Normal file
@@ -0,0 +1,802 @@
|
||||
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
class KVAEResnetBlock2D(nn.Module):
|
||||
r"""
|
||||
A Resnet block with optional guidance.
|
||||
|
||||
Parameters:
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
out_channels (`int`, *optional*, default to `None`):
|
||||
The number of output channels for the first conv2d layer. If None, same as `in_channels`.
|
||||
conv_shortcut (`bool`, *optional*, default to `False`):
|
||||
If `True` and `in_channels` not equal to `out_channels`, add a 3x3 nn.conv2d layer for skip-connection.
|
||||
temb_channels (`int`, *optional*, default to `512`): The number of channels in timestep embedding.
|
||||
zq_ch (`int`, *optional*, default to `None`): Guidance channels for normalization.
|
||||
add_conv (`bool`, *optional*, default to `False`):
|
||||
If `True` add conv2d layer for normalization.
|
||||
normalization (`nn.Module`, *optional*, default to `None`): The normalization layer.
|
||||
act_fn (`str`, *optional*, default to `"swish"`): The activation function to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
conv_shortcut: bool = False,
|
||||
temb_channels: int = 512,
|
||||
zq_ch: Optional[int] = None,
|
||||
add_conv: bool = False,
|
||||
act_fn: str = "swish",
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
self.nonlinearity = get_activation(act_fn)
|
||||
|
||||
if zq_ch is None:
|
||||
self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
|
||||
else:
|
||||
self.norm1 = KVAEDecoderSpatialNorm2D(in_channels, zq_channels=zq_ch, add_conv=add_conv)
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=(1, 1), padding_mode="replicate"
|
||||
)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
if zq_ch is None:
|
||||
self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True)
|
||||
else:
|
||||
self.norm2 = KVAEDecoderSpatialNorm2D(out_channels, zq_channels=zq_ch, add_conv=add_conv)
|
||||
self.conv2 = nn.Conv2d(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
padding=(1, 1),
|
||||
padding_mode="replicate",
|
||||
)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
padding=(1, 1),
|
||||
padding_mode="replicate",
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: torch.Tensor, zq: torch.Tensor = None) -> torch.Tensor:
|
||||
h = x
|
||||
|
||||
if zq is None:
|
||||
h = self.norm1(h)
|
||||
else:
|
||||
h = self.norm1(h, zq)
|
||||
|
||||
h = self.nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
if zq is None:
|
||||
h = self.norm2(h)
|
||||
else:
|
||||
h = self.norm2(h, zq)
|
||||
|
||||
h = self.nonlinearity(h)
|
||||
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class KVAEPXSDownsample(nn.Module):
|
||||
def __init__(self, in_channels: int, factor: int = 2):
|
||||
r"""
|
||||
A Downsampling module.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
factor (`int`, *optional*, default to `2`): The downsampling factor.
|
||||
"""
|
||||
super().__init__()
|
||||
self.factor = factor
|
||||
self.unshuffle = nn.PixelUnshuffle(self.factor)
|
||||
self.spatial_conv = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode="reflect"
|
||||
)
|
||||
self.linear = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# x: (bchw)
|
||||
pxs_interm = self.unshuffle(x)
|
||||
b, c, h, w = pxs_interm.shape
|
||||
pxs_interm_view = pxs_interm.view(b, c // self.factor**2, self.factor**2, h, w)
|
||||
pxs_out = torch.mean(pxs_interm_view, dim=2)
|
||||
|
||||
conv_out = self.spatial_conv(x)
|
||||
|
||||
# adding it all together
|
||||
out = conv_out + pxs_out
|
||||
return self.linear(out)
|
||||
|
||||
|
||||
class KVAEPXSUpsample(nn.Module):
|
||||
def __init__(self, in_channels: int, factor: int = 2):
|
||||
r"""
|
||||
An Upsampling module.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
factor (`int`, *optional*, default to `2`): The upsampling factor.
|
||||
"""
|
||||
super().__init__()
|
||||
self.factor = factor
|
||||
self.shuffle = nn.PixelShuffle(self.factor)
|
||||
self.spatial_conv = nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode="reflect"
|
||||
)
|
||||
|
||||
self.linear = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
repeated = x.repeat_interleave(self.factor**2, dim=1)
|
||||
pxs_interm = self.shuffle(repeated)
|
||||
|
||||
image_like_ups = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
conv_out = self.spatial_conv(image_like_ups)
|
||||
|
||||
# adding it all together
|
||||
out = conv_out + pxs_interm
|
||||
return self.linear(out)
|
||||
|
||||
|
||||
class KVAEDecoderSpatialNorm2D(nn.Module):
|
||||
r"""
|
||||
A 2D normalization module for decoder.
|
||||
|
||||
Args:
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
zq_channels (`int`): The number of channels in the guidance.
|
||||
add_conv (`bool`, *optional*, default to `false`):
|
||||
If `True` add conv2d 3x3 layer for guidance in the beginning.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
zq_channels: int,
|
||||
add_conv: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_layer = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
|
||||
|
||||
self.add_conv = add_conv
|
||||
if add_conv:
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels=zq_channels,
|
||||
out_channels=zq_channels,
|
||||
kernel_size=3,
|
||||
padding=(1, 1),
|
||||
padding_mode="replicate",
|
||||
)
|
||||
|
||||
self.conv_y = nn.Conv2d(
|
||||
in_channels=zq_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
self.conv_b = nn.Conv2d(
|
||||
in_channels=zq_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
|
||||
f_first = f
|
||||
f_first_size = f_first.shape[2:]
|
||||
zq = F.interpolate(zq, size=f_first_size, mode="nearest")
|
||||
|
||||
if self.add_conv:
|
||||
zq = self.conv(zq)
|
||||
|
||||
norm_f = self.norm_layer(f)
|
||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||
return new_f
|
||||
|
||||
|
||||
class KVAEEncoder2D(nn.Module):
|
||||
r"""
|
||||
A 2D encoder module.
|
||||
|
||||
Args:
|
||||
ch (`int`): The base number of channels in multiresolution blocks.
|
||||
ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`):
|
||||
The channel multipliers in multiresolution blocks.
|
||||
num_res_blocks (`int`): The number of Resnet blocks.
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
z_channels (`int`): The number of output channels.
|
||||
double_z (`bool`, *optional*, defaults to `True`):
|
||||
Whether to double the number of output channels for the last block.
|
||||
act_fn (`str`, *optional*, default to `"swish"`): The activation function to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch: int,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
||||
num_res_blocks: int,
|
||||
in_channels: int,
|
||||
z_channels: int,
|
||||
double_z: bool = True,
|
||||
act_fn: str = "swish",
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = [num_res_blocks] * self.num_resolutions
|
||||
else:
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.nonlinearity = get_activation(act_fn)
|
||||
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=self.ch,
|
||||
kernel_size=3,
|
||||
padding=(1, 1),
|
||||
)
|
||||
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks[i_level]):
|
||||
block.append(
|
||||
KVAEResnetBlock2D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level < self.num_resolutions - 1:
|
||||
down.downsample = KVAEPXSDownsample(in_channels=block_in) # mb: bad out channels
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = KVAEResnetBlock2D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
)
|
||||
|
||||
self.mid.block_2 = KVAEResnetBlock2D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
)
|
||||
|
||||
# end
|
||||
self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)
|
||||
|
||||
self.conv_out = nn.Conv2d(
|
||||
in_channels=block_in,
|
||||
out_channels=2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
padding=(1, 1),
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
h = self.conv_in(x)
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks[i_level]):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(self.down[i_level].block[i_block], h, temb)
|
||||
else:
|
||||
h = self.down[i_level].block[i_block](h, temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.down[i_level].downsample(h)
|
||||
|
||||
# middle
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb)
|
||||
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb)
|
||||
else:
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = self.nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class KVAEDecoder2D(nn.Module):
|
||||
r"""
|
||||
A 2D decoder module.
|
||||
|
||||
Args:
|
||||
ch (`int`): The base number of channels in multiresolution blocks.
|
||||
out_ch (`int`): The number of output channels.
|
||||
ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`):
|
||||
The channel multipliers in multiresolution blocks.
|
||||
num_res_blocks (`int`): The number of Resnet blocks.
|
||||
in_channels (`int`): The number of channels in the input.
|
||||
z_channels (`int`): The number of input channels.
|
||||
give_pre_end (`bool`, *optional*, default to `false`):
|
||||
If `True` exit the forward pass early and return the penultimate feature map.
|
||||
zq_ch (`bool`, *optional*, default to `None`): The number of channels in the guidance.
|
||||
add_conv (`bool`, *optional*, default to `false`): If `True` add conv2d layer for Resnet normalization layer.
|
||||
act_fn (`str`, *optional*, default to `"swish"`): The activation function to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ch: int,
|
||||
out_ch: int,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
||||
num_res_blocks: int,
|
||||
in_channels: int,
|
||||
z_channels: int,
|
||||
give_pre_end: bool = False,
|
||||
zq_ch: Optional[int] = None,
|
||||
add_conv: bool = False,
|
||||
act_fn: str = "swish",
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.nonlinearity = get_activation(act_fn)
|
||||
|
||||
if zq_ch is None:
|
||||
zq_ch = z_channels
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels=z_channels, out_channels=block_in, kernel_size=3, padding=(1, 1), padding_mode="replicate"
|
||||
)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = KVAEResnetBlock2D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
)
|
||||
|
||||
self.mid.block_2 = KVAEResnetBlock2D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
KVAEResnetBlock2D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = KVAEPXSUpsample(in_channels=block_in)
|
||||
self.up.insert(0, up)
|
||||
|
||||
self.norm_out = KVAEDecoderSpatialNorm2D(block_in, zq_ch, add_conv=add_conv) # , gather=gather_norm)
|
||||
|
||||
self.conv_out = nn.Conv2d(
|
||||
in_channels=block_in, out_channels=out_ch, kernel_size=3, padding=(1, 1), padding_mode="replicate"
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, z: torch.Tensor) -> torch.Tensor:
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
zq = z
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, zq)
|
||||
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, zq)
|
||||
else:
|
||||
h = self.mid.block_1(h, temb, zq)
|
||||
h = self.mid.block_2(h, temb, zq)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(self.up[i_level].block[i_block], h, temb, zq)
|
||||
else:
|
||||
h = self.up[i_level].block[i_block](h, temb, zq)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, zq)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h, zq)
|
||||
h = self.nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class AutoencoderKLKVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
|
||||
all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
|
||||
channels (int, *optional*, defaults to 128): The base number of channels in multiresolution blocks.
|
||||
num_enc_blocks (int, *optional*, defaults to 2):
|
||||
The number of Resnet blocks in encoder multiresolution layers.
|
||||
num_dec_blocks (int, *optional*, defaults to 2):
|
||||
The number of Resnet blocks in decoder multiresolution layers.
|
||||
z_channels (int, *optional*, defaults to 16): Number of channels in the latent space.
|
||||
double_z (`bool`, *optional*, defaults to `True`):
|
||||
Whether to double the number of output channels of encoder.
|
||||
ch_mult (`Tuple[int, ...]`, *optional*, default to `(1, 2, 4, 8)`):
|
||||
The channel multipliers in multiresolution blocks.
|
||||
sample_size (`int`, *optional*, defaults to `1024`): Sample input size.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
channels: int = 128,
|
||||
num_enc_blocks: int = 2,
|
||||
num_dec_blocks: int = 2,
|
||||
z_channels: int = 16,
|
||||
double_z: bool = True,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
||||
sample_size: int = 1024,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# pass init params to Encoder
|
||||
self.encoder = KVAEEncoder2D(
|
||||
in_channels=in_channels,
|
||||
ch=channels,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_enc_blocks,
|
||||
z_channels=z_channels,
|
||||
double_z=double_z,
|
||||
)
|
||||
|
||||
# pass init params to Decoder
|
||||
self.decoder = KVAEDecoder2D(
|
||||
out_ch=in_channels,
|
||||
ch=channels,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_dec_blocks,
|
||||
in_channels=None,
|
||||
z_channels=z_channels,
|
||||
)
|
||||
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
# only relevant if vae tiling is enabled
|
||||
self.tile_sample_min_size = self.config.sample_size
|
||||
sample_size = (
|
||||
self.config.sample_size[0]
|
||||
if isinstance(self.config.sample_size, (list, tuple))
|
||||
else self.config.sample_size
|
||||
)
|
||||
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.ch_mult) - 1)))
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, height, width = x.shape
|
||||
|
||||
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
|
||||
return self._tiled_encode(x)
|
||||
|
||||
enc = self.encoder(x)
|
||||
|
||||
return enc
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
"""
|
||||
Encode a batch of images into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded images. If `return_dict` is True, a
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
dec = self.decoder(z)
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(
|
||||
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
|
||||
for y in range(blend_extent):
|
||||
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
|
||||
return b
|
||||
|
||||
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
||||
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
|
||||
for x in range(blend_extent):
|
||||
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
||||
return b
|
||||
|
||||
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
|
||||
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
|
||||
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
|
||||
output, but they should be much less noticeable.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The latent representation of the encoded videos.
|
||||
"""
|
||||
|
||||
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_latent_min_size - blend_extent
|
||||
|
||||
# Split the image into 512x512 tiles and encode them separately.
|
||||
rows = []
|
||||
for i in range(0, x.shape[2], overlap_size):
|
||||
row = []
|
||||
for j in range(0, x.shape[3], overlap_size):
|
||||
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
|
||||
tile = self.encoder(tile)
|
||||
row.append(tile)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=3))
|
||||
|
||||
enc = torch.cat(result_rows, dim=2)
|
||||
return enc
|
||||
|
||||
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Decode a batch of images using a tiled decoder.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
"""
|
||||
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
|
||||
row_limit = self.tile_sample_min_size - blend_extent
|
||||
|
||||
# Split z into overlapping 64x64 tiles and decode them separately.
|
||||
# The tiles have an overlap to avoid seams between tiles.
|
||||
rows = []
|
||||
for i in range(0, z.shape[2], overlap_size):
|
||||
row = []
|
||||
for j in range(0, z.shape[3], overlap_size):
|
||||
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
|
||||
decoded = self.decoder(tile)
|
||||
row.append(decoded)
|
||||
rows.append(row)
|
||||
result_rows = []
|
||||
for i, row in enumerate(rows):
|
||||
result_row = []
|
||||
for j, tile in enumerate(row):
|
||||
# blend the above tile and the left tile
|
||||
# to the current tile and add the current tile to the result row
|
||||
if i > 0:
|
||||
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
|
||||
if j > 0:
|
||||
tile = self.blend_h(row[j - 1], tile, blend_extent)
|
||||
result_row.append(tile[:, :, :row_limit, :row_limit])
|
||||
result_rows.append(torch.cat(result_row, dim=3))
|
||||
|
||||
dec = torch.cat(result_rows, dim=2)
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
sample (`torch.Tensor`): Input sample.
|
||||
sample_posterior (`bool`, *optional*, defaults to `False`):
|
||||
Whether to sample from the posterior.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
|
||||
"""
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
|
||||
return DecoderOutput(sample=dec)
|
||||
954
src/diffusers/models/autoencoders/autoencoder_kl_kvae_video.py
Normal file
954
src/diffusers/models/autoencoders/autoencoder_kl_kvae_video.py
Normal file
@@ -0,0 +1,954 @@
|
||||
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def nonlinearity(x: torch.Tensor) -> torch.Tensor:
|
||||
return F.silu(x)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Base layers
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class KVAESafeConv3d(nn.Conv3d):
|
||||
r"""
|
||||
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM.
|
||||
"""
|
||||
|
||||
def forward(self, input: torch.Tensor, write_to: torch.Tensor = None) -> torch.Tensor:
|
||||
memory_count = input.numel() * input.element_size() / (10**9)
|
||||
|
||||
if memory_count > 3:
|
||||
kernel_size = self.kernel_size[0]
|
||||
part_num = math.ceil(memory_count / 2)
|
||||
input_chunks = torch.chunk(input, part_num, dim=2)
|
||||
|
||||
if write_to is None:
|
||||
output = []
|
||||
for i, chunk in enumerate(input_chunks):
|
||||
if i == 0 or kernel_size == 1:
|
||||
z = torch.clone(chunk)
|
||||
else:
|
||||
z = torch.cat([z[:, :, -kernel_size + 1 :], chunk], dim=2)
|
||||
output.append(super().forward(z))
|
||||
return torch.cat(output, dim=2)
|
||||
else:
|
||||
time_offset = 0
|
||||
for i, chunk in enumerate(input_chunks):
|
||||
if i == 0 or kernel_size == 1:
|
||||
z = torch.clone(chunk)
|
||||
else:
|
||||
z = torch.cat([z[:, :, -kernel_size + 1 :], chunk], dim=2)
|
||||
z_time = z.size(2) - (kernel_size - 1)
|
||||
write_to[:, :, time_offset : time_offset + z_time] = super().forward(z)
|
||||
time_offset += z_time
|
||||
return write_to
|
||||
else:
|
||||
if write_to is None:
|
||||
return super().forward(input)
|
||||
else:
|
||||
write_to[...] = super().forward(input)
|
||||
return write_to
|
||||
|
||||
|
||||
class KVAECausalConv3d(nn.Module):
|
||||
r"""
|
||||
A 3D causal convolution layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chan_in: int,
|
||||
chan_out: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Tuple[int, int, int] = (1, 1, 1),
|
||||
dilation: Tuple[int, int, int] = (1, 1, 1),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
|
||||
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
|
||||
|
||||
self.height_pad = height_kernel_size // 2
|
||||
self.width_pad = width_kernel_size // 2
|
||||
self.time_pad = time_kernel_size - 1
|
||||
self.time_kernel_size = time_kernel_size
|
||||
self.stride = stride
|
||||
|
||||
self.conv = KVAESafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
padding_3d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad, self.time_pad, 0)
|
||||
input_padded = F.pad(input, padding_3d, mode="replicate")
|
||||
return self.conv(input_padded)
|
||||
|
||||
|
||||
class KVAECachedCausalConv3d(nn.Module):
|
||||
r"""
|
||||
A 3D causal convolution layer with caching for temporal processing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chan_in: int,
|
||||
chan_out: int,
|
||||
kernel_size: Union[int, Tuple[int, int, int]],
|
||||
stride: Tuple[int, int, int] = (1, 1, 1),
|
||||
dilation: Tuple[int, int, int] = (1, 1, 1),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||
|
||||
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
|
||||
|
||||
self.height_pad = height_kernel_size // 2
|
||||
self.width_pad = width_kernel_size // 2
|
||||
self.time_pad = time_kernel_size - 1
|
||||
self.time_kernel_size = time_kernel_size
|
||||
self.stride = stride
|
||||
|
||||
self.conv = KVAESafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, input: torch.Tensor, cache: Dict) -> torch.Tensor:
|
||||
t_stride = self.stride[0]
|
||||
padding_3d = (self.height_pad, self.height_pad, self.width_pad, self.width_pad, 0, 0)
|
||||
input_parallel = F.pad(input, padding_3d, mode="replicate")
|
||||
|
||||
if cache["padding"] is None:
|
||||
first_frame = input_parallel[:, :, :1]
|
||||
time_pad_shape = list(first_frame.shape)
|
||||
time_pad_shape[2] = self.time_pad
|
||||
padding = first_frame.expand(time_pad_shape)
|
||||
else:
|
||||
padding = cache["padding"]
|
||||
|
||||
out_size = list(input.shape)
|
||||
out_size[1] = self.conv.out_channels
|
||||
if t_stride == 2:
|
||||
out_size[2] = (input.size(2) + 1) // 2
|
||||
output = torch.empty(tuple(out_size), dtype=input.dtype, device=input.device)
|
||||
|
||||
offset_out = math.ceil(padding.size(2) / t_stride)
|
||||
offset_in = offset_out * t_stride - padding.size(2)
|
||||
|
||||
if offset_out > 0:
|
||||
padding_poisoned = torch.cat(
|
||||
[padding, input_parallel[:, :, : offset_in + self.time_kernel_size - t_stride]], dim=2
|
||||
)
|
||||
output[:, :, :offset_out] = self.conv(padding_poisoned)
|
||||
|
||||
if offset_out < output.size(2):
|
||||
output[:, :, offset_out:] = self.conv(input_parallel[:, :, offset_in:])
|
||||
|
||||
pad_offset = (
|
||||
offset_in
|
||||
+ t_stride * math.trunc((input_parallel.size(2) - offset_in - self.time_kernel_size) / t_stride)
|
||||
+ t_stride
|
||||
)
|
||||
cache["padding"] = torch.clone(input_parallel[:, :, pad_offset:])
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class KVAECachedGroupNorm(nn.Module):
|
||||
r"""
|
||||
GroupNorm with caching support for temporal processing.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.norm_layer = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: Dict = None) -> torch.Tensor:
|
||||
out = self.norm_layer(x)
|
||||
if cache is not None and cache.get("mean") is None and cache.get("var") is None:
|
||||
cache["mean"] = 1
|
||||
cache["var"] = 1
|
||||
return out
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cached layers
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class KVAECachedSpatialNorm3D(nn.Module):
|
||||
r"""
|
||||
Spatially conditioned normalization for decoder with caching.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
f_channels: int,
|
||||
zq_channels: int,
|
||||
add_conv: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_layer = KVAECachedGroupNorm(f_channels)
|
||||
self.add_conv = add_conv
|
||||
|
||||
if add_conv:
|
||||
self.conv = KVAECachedCausalConv3d(chan_in=zq_channels, chan_out=zq_channels, kernel_size=3)
|
||||
|
||||
self.conv_y = KVAESafeConv3d(zq_channels, f_channels, kernel_size=1)
|
||||
self.conv_b = KVAESafeConv3d(zq_channels, f_channels, kernel_size=1)
|
||||
|
||||
def forward(self, f: torch.Tensor, zq: torch.Tensor, cache: Dict) -> torch.Tensor:
|
||||
if cache["norm"].get("mean") is None and cache["norm"].get("var") is None:
|
||||
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
||||
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
||||
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
|
||||
|
||||
zq_first = F.interpolate(zq_first, size=f_first_size, mode="nearest")
|
||||
|
||||
if zq.size(2) > 1:
|
||||
zq_rest_splits = torch.split(zq_rest, 32, dim=1)
|
||||
interpolated_splits = [
|
||||
F.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits
|
||||
]
|
||||
zq_rest = torch.cat(interpolated_splits, dim=1)
|
||||
zq = torch.cat([zq_first, zq_rest], dim=2)
|
||||
else:
|
||||
zq = zq_first
|
||||
else:
|
||||
f_size = f.shape[-3:]
|
||||
zq_splits = torch.split(zq, 32, dim=1)
|
||||
interpolated_splits = [F.interpolate(split, size=f_size, mode="nearest") for split in zq_splits]
|
||||
zq = torch.cat(interpolated_splits, dim=1)
|
||||
|
||||
if self.add_conv:
|
||||
zq = self.conv(zq, cache["add_conv"])
|
||||
|
||||
norm_f = self.norm_layer(f, cache["norm"])
|
||||
norm_f = norm_f * self.conv_y(zq)
|
||||
norm_f = norm_f + self.conv_b(zq)
|
||||
|
||||
return norm_f
|
||||
|
||||
|
||||
class KVAECachedResnetBlock3D(nn.Module):
|
||||
r"""
|
||||
A 3D ResNet block with caching.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
conv_shortcut: bool = False,
|
||||
dropout: float = 0.0,
|
||||
temb_channels: int = 0,
|
||||
zq_ch: Optional[int] = None,
|
||||
add_conv: bool = False,
|
||||
gather_norm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
if zq_ch is None:
|
||||
self.norm1 = KVAECachedGroupNorm(in_channels)
|
||||
else:
|
||||
self.norm1 = KVAECachedSpatialNorm3D(in_channels, zq_ch, add_conv=add_conv)
|
||||
|
||||
self.conv1 = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=out_channels, kernel_size=3)
|
||||
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = nn.Linear(temb_channels, out_channels)
|
||||
|
||||
if zq_ch is None:
|
||||
self.norm2 = KVAECachedGroupNorm(out_channels)
|
||||
else:
|
||||
self.norm2 = KVAECachedSpatialNorm3D(out_channels, zq_ch, add_conv=add_conv)
|
||||
|
||||
self.conv2 = KVAECachedCausalConv3d(chan_in=out_channels, chan_out=out_channels, kernel_size=3)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=out_channels, kernel_size=3)
|
||||
else:
|
||||
self.nin_shortcut = KVAESafeConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: torch.Tensor, layer_cache: Dict, zq: torch.Tensor = None) -> torch.Tensor:
|
||||
h = x
|
||||
|
||||
if zq is None:
|
||||
# Encoder path - norm takes cache
|
||||
h = self.norm1(h, cache=layer_cache["norm1"])
|
||||
else:
|
||||
# Decoder path - spatial norm takes zq and cache
|
||||
h = self.norm1(h, zq, cache=layer_cache["norm1"])
|
||||
|
||||
h = F.silu(h)
|
||||
h = self.conv1(h, cache=layer_cache["conv1"])
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
if zq is None:
|
||||
h = self.norm2(h, cache=layer_cache["norm2"])
|
||||
else:
|
||||
h = self.norm2(h, zq, cache=layer_cache["norm2"])
|
||||
|
||||
h = F.silu(h)
|
||||
h = self.conv2(h, cache=layer_cache["conv2"])
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x, cache=layer_cache["conv_shortcut"])
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class KVAECachedPXSDownsample(nn.Module):
|
||||
r"""
|
||||
A 3D downsampling layer using PixelUnshuffle with caching.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, compress_time: bool, factor: int = 2):
|
||||
super().__init__()
|
||||
self.temporal_compress = compress_time
|
||||
self.factor = factor
|
||||
self.unshuffle = nn.PixelUnshuffle(self.factor)
|
||||
self.s_pool = nn.AvgPool3d((1, 2, 2), (1, 2, 2))
|
||||
|
||||
self.spatial_conv = KVAESafeConv3d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=(1, 3, 3),
|
||||
stride=(1, 2, 2),
|
||||
padding=(0, 1, 1),
|
||||
padding_mode="reflect",
|
||||
)
|
||||
|
||||
if self.temporal_compress:
|
||||
self.temporal_conv = KVAECachedCausalConv3d(
|
||||
in_channels, in_channels, kernel_size=(3, 1, 1), stride=(2, 1, 1), dilation=(1, 1, 1)
|
||||
)
|
||||
|
||||
self.linear = nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1)
|
||||
|
||||
def spatial_downsample(self, input: torch.Tensor) -> torch.Tensor:
|
||||
b, c, t, h, w = input.shape
|
||||
pxs_input = input.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
# pxs_input = rearrange(input, 'b c t h w -> (b t) c h w')
|
||||
pxs_interm = self.unshuffle(pxs_input)
|
||||
b_it, c_it, h_it, w_it = pxs_interm.shape
|
||||
pxs_interm_view = pxs_interm.view(b_it, c_it // self.factor**2, self.factor**2, h_it, w_it)
|
||||
pxs_out = torch.mean(pxs_interm_view, dim=2)
|
||||
pxs_out = pxs_out.view(b, t, -1, h_it, w_it).permute(0, 2, 1, 3, 4)
|
||||
# pxs_out = rearrange(pxs_out, '(b t) c h w -> b c t h w', t=input.size(2))
|
||||
conv_out = self.spatial_conv(input)
|
||||
return conv_out + pxs_out
|
||||
|
||||
def temporal_downsample(self, input: torch.Tensor, cache: list) -> torch.Tensor:
|
||||
b, c, t, h, w = input.shape
|
||||
|
||||
permuted = input.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t)
|
||||
|
||||
if cache[0]["padding"] is None:
|
||||
first, rest = permuted[..., :1], permuted[..., 1:]
|
||||
if rest.size(-1) > 0:
|
||||
rest_interp = F.avg_pool1d(rest, kernel_size=2, stride=2)
|
||||
full_interp = torch.cat([first, rest_interp], dim=-1)
|
||||
else:
|
||||
full_interp = first
|
||||
else:
|
||||
rest = permuted
|
||||
if rest.size(-1) > 0:
|
||||
full_interp = F.avg_pool1d(rest, kernel_size=2, stride=2)
|
||||
|
||||
t_new = full_interp.size(-1)
|
||||
full_interp = full_interp.view(b, h, w, c, t_new).permute(0, 3, 4, 1, 2)
|
||||
conv_out = self.temporal_conv(input, cache[0])
|
||||
return conv_out + full_interp
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: list) -> torch.Tensor:
|
||||
out = self.spatial_downsample(x)
|
||||
|
||||
if self.temporal_compress:
|
||||
out = self.temporal_downsample(out, cache=cache)
|
||||
|
||||
return self.linear(out)
|
||||
|
||||
|
||||
class KVAECachedPXSUpsample(nn.Module):
|
||||
r"""
|
||||
A 3D upsampling layer using PixelShuffle with caching.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, compress_time: bool, factor: int = 2):
|
||||
super().__init__()
|
||||
self.temporal_compress = compress_time
|
||||
self.factor = factor
|
||||
self.shuffle = nn.PixelShuffle(self.factor)
|
||||
|
||||
self.spatial_conv = KVAESafeConv3d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=(1, 3, 3),
|
||||
stride=(1, 1, 1),
|
||||
padding=(0, 1, 1),
|
||||
padding_mode="reflect",
|
||||
)
|
||||
|
||||
if self.temporal_compress:
|
||||
self.temporal_conv = KVAECachedCausalConv3d(
|
||||
in_channels, in_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), dilation=(1, 1, 1)
|
||||
)
|
||||
|
||||
self.linear = KVAESafeConv3d(in_channels, in_channels, kernel_size=1, stride=1)
|
||||
|
||||
def spatial_upsample(self, input: torch.Tensor) -> torch.Tensor:
|
||||
b, c, t, h, w = input.shape
|
||||
input_view = input.permute(0, 2, 1, 3, 4).reshape(b, t * c, h, w)
|
||||
input_interp = F.interpolate(input_view, scale_factor=2, mode="nearest")
|
||||
input_interp = input_interp.view(b, t, c, 2 * h, 2 * w).permute(0, 2, 1, 3, 4)
|
||||
|
||||
out = self.spatial_conv(input_interp)
|
||||
return input_interp + out
|
||||
|
||||
def temporal_upsample(self, input: torch.Tensor, cache: Dict) -> torch.Tensor:
|
||||
time_factor = 1.0 + 1.0 * (input.size(2) > 1)
|
||||
if isinstance(time_factor, torch.Tensor):
|
||||
time_factor = time_factor.item()
|
||||
|
||||
repeated = input.repeat_interleave(int(time_factor), dim=2)
|
||||
|
||||
if cache["padding"] is None:
|
||||
tail = repeated[..., int(time_factor - 1) :, :, :]
|
||||
else:
|
||||
tail = repeated
|
||||
|
||||
conv_out = self.temporal_conv(tail, cache)
|
||||
return conv_out + tail
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: Dict) -> torch.Tensor:
|
||||
if self.temporal_compress:
|
||||
x = self.temporal_upsample(x, cache)
|
||||
|
||||
s_out = self.spatial_upsample(x)
|
||||
to = torch.empty_like(s_out)
|
||||
lin_out = self.linear(s_out, write_to=to)
|
||||
return lin_out
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cached Encoder/Decoder
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class KVAECachedEncoder3D(nn.Module):
|
||||
r"""
|
||||
Cached 3D Encoder for KVAE.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ch: int = 128,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
||||
num_res_blocks: int = 2,
|
||||
dropout: float = 0.0,
|
||||
in_channels: int = 3,
|
||||
z_channels: int = 16,
|
||||
double_z: bool = True,
|
||||
temporal_compress_times: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.in_channels = in_channels
|
||||
self.temporal_compress_level = int(np.log2(temporal_compress_times))
|
||||
|
||||
self.conv_in = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=self.ch, kernel_size=3)
|
||||
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
block_in = ch
|
||||
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
KVAECachedResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
dropout=dropout,
|
||||
temb_channels=self.temb_ch,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
|
||||
if i_level != self.num_resolutions - 1:
|
||||
if i_level < self.temporal_compress_level:
|
||||
down.downsample = KVAECachedPXSDownsample(block_in, compress_time=True)
|
||||
else:
|
||||
down.downsample = KVAECachedPXSDownsample(block_in, compress_time=False)
|
||||
self.down.append(down)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = KVAECachedResnetBlock3D(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
self.mid.block_2 = KVAECachedResnetBlock3D(
|
||||
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
|
||||
)
|
||||
|
||||
self.norm_out = KVAECachedGroupNorm(block_in)
|
||||
self.conv_out = KVAECachedCausalConv3d(
|
||||
chan_in=block_in, chan_out=2 * z_channels if double_z else z_channels, kernel_size=3
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x: torch.Tensor, cache_dict: Dict) -> torch.Tensor:
|
||||
temb = None
|
||||
|
||||
h = self.conv_in(x, cache=cache_dict["conv_in"])
|
||||
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(
|
||||
self.down[i_level].block[i_block], h, temb, cache_dict[i_level][i_block]
|
||||
)
|
||||
else:
|
||||
h = self.down[i_level].block[i_block](h, temb, layer_cache=cache_dict[i_level][i_block])
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.down[i_level].downsample(h, cache=cache_dict[i_level]["down"])
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, cache_dict["mid_1"])
|
||||
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, cache_dict["mid_2"])
|
||||
else:
|
||||
h = self.mid.block_1(h, temb, layer_cache=cache_dict["mid_1"])
|
||||
h = self.mid.block_2(h, temb, layer_cache=cache_dict["mid_2"])
|
||||
|
||||
h = self.norm_out(h, cache=cache_dict["norm_out"])
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, cache=cache_dict["conv_out"])
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class KVAECachedDecoder3D(nn.Module):
|
||||
r"""
|
||||
Cached 3D Decoder for KVAE.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ch: int = 128,
|
||||
out_ch: int = 3,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
||||
num_res_blocks: int = 2,
|
||||
dropout: float = 0.0,
|
||||
z_channels: int = 16,
|
||||
zq_ch: Optional[int] = None,
|
||||
add_conv: bool = False,
|
||||
temporal_compress_times: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.temporal_compress_level = int(np.log2(temporal_compress_times))
|
||||
|
||||
if zq_ch is None:
|
||||
zq_ch = z_channels
|
||||
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
|
||||
self.conv_in = KVAECachedCausalConv3d(chan_in=z_channels, chan_out=block_in, kernel_size=3)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = KVAECachedResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
)
|
||||
self.mid.block_2 = KVAECachedResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
KVAECachedResnetBlock3D(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
zq_ch=zq_ch,
|
||||
add_conv=add_conv,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
|
||||
if i_level != 0:
|
||||
if i_level < self.num_resolutions - self.temporal_compress_level:
|
||||
up.upsample = KVAECachedPXSUpsample(block_in, compress_time=False)
|
||||
else:
|
||||
up.upsample = KVAECachedPXSUpsample(block_in, compress_time=True)
|
||||
self.up.insert(0, up)
|
||||
|
||||
self.norm_out = KVAECachedSpatialNorm3D(block_in, zq_ch, add_conv=add_conv)
|
||||
self.conv_out = KVAECachedCausalConv3d(chan_in=block_in, chan_out=out_ch, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, z: torch.Tensor, cache_dict: Dict) -> torch.Tensor:
|
||||
temb = None
|
||||
zq = z
|
||||
|
||||
h = self.conv_in(z, cache_dict["conv_in"])
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, cache_dict["mid_1"], zq)
|
||||
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, cache_dict["mid_2"], zq)
|
||||
else:
|
||||
h = self.mid.block_1(h, temb, layer_cache=cache_dict["mid_1"], zq=zq)
|
||||
h = self.mid.block_2(h, temb, layer_cache=cache_dict["mid_2"], zq=zq)
|
||||
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
h = self._gradient_checkpointing_func(
|
||||
self.up[i_level].block[i_block], h, temb, cache_dict[i_level][i_block], zq
|
||||
)
|
||||
else:
|
||||
h = self.up[i_level].block[i_block](h, temb, layer_cache=cache_dict[i_level][i_block], zq=zq)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, zq)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h, cache_dict[i_level]["up"])
|
||||
|
||||
h = self.norm_out(h, zq, cache_dict["norm_out"])
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, cache_dict["conv_out"])
|
||||
|
||||
return h
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Main AutoencoderKL class
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AutoencoderKLKVAEVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
|
||||
[KVAE](https://github.com/kandinskylab/kvae-1).
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
|
||||
all models (such as downloading or saving).
|
||||
|
||||
Parameters:
|
||||
ch (`int`, *optional*, defaults to 128): Base channel count.
|
||||
ch_mult (`Tuple[int]`, *optional*, defaults to `(1, 2, 4, 8)`): Channel multipliers per level.
|
||||
num_res_blocks (`int`, *optional*, defaults to 2): Number of residual blocks per level.
|
||||
in_channels (`int`, *optional*, defaults to 3): Number of input channels.
|
||||
out_ch (`int`, *optional*, defaults to 3): Number of output channels.
|
||||
z_channels (`int`, *optional*, defaults to 16): Number of latent channels.
|
||||
temporal_compress_times (`int`, *optional*, defaults to 4): Temporal compression factor.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["KVAECachedResnetBlock3D"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
ch: int = 128,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
|
||||
num_res_blocks: int = 2,
|
||||
in_channels: int = 3,
|
||||
out_ch: int = 3,
|
||||
z_channels: int = 16,
|
||||
temporal_compress_times: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.encoder = KVAECachedEncoder3D(
|
||||
ch=ch,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
in_channels=in_channels,
|
||||
z_channels=z_channels,
|
||||
double_z=True,
|
||||
temporal_compress_times=temporal_compress_times,
|
||||
)
|
||||
|
||||
self.decoder = KVAECachedDecoder3D(
|
||||
ch=ch,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
out_ch=out_ch,
|
||||
z_channels=z_channels,
|
||||
temporal_compress_times=temporal_compress_times,
|
||||
)
|
||||
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
def _make_encoder_cache(self) -> Dict:
|
||||
"""Create empty cache for cached encoder."""
|
||||
|
||||
def make_dict(name, p=None):
|
||||
if name == "conv":
|
||||
return {"padding": None}
|
||||
|
||||
layer, module = name.split("_")
|
||||
if layer == "norm":
|
||||
if module == "enc":
|
||||
return {"mean": None, "var": None}
|
||||
else:
|
||||
return {"norm": make_dict("norm_enc"), "add_conv": make_dict("conv")}
|
||||
elif layer == "resblock":
|
||||
return {
|
||||
"norm1": make_dict(f"norm_{module}"),
|
||||
"norm2": make_dict(f"norm_{module}"),
|
||||
"conv1": make_dict("conv"),
|
||||
"conv2": make_dict("conv"),
|
||||
"conv_shortcut": make_dict("conv"),
|
||||
}
|
||||
elif layer.isdigit():
|
||||
out_dict = {"down": [make_dict("conv"), make_dict("conv")], "up": make_dict("conv")}
|
||||
for i in range(p):
|
||||
out_dict[i] = make_dict(f"resblock_{module}")
|
||||
return out_dict
|
||||
|
||||
cache = {
|
||||
"conv_in": make_dict("conv"),
|
||||
"mid_1": make_dict("resblock_enc"),
|
||||
"mid_2": make_dict("resblock_enc"),
|
||||
"norm_out": make_dict("norm_enc"),
|
||||
"conv_out": make_dict("conv"),
|
||||
}
|
||||
# Encoder uses num_res_blocks per level
|
||||
for i in range(len(self.config.ch_mult)):
|
||||
cache[i] = make_dict(f"{i}_enc", p=self.config.num_res_blocks)
|
||||
return cache
|
||||
|
||||
def _make_decoder_cache(self) -> Dict:
|
||||
"""Create empty cache for decoder."""
|
||||
|
||||
def make_dict(name, p=None):
|
||||
if name == "conv":
|
||||
return {"padding": None}
|
||||
|
||||
layer, module = name.split("_")
|
||||
if layer == "norm":
|
||||
if module == "enc":
|
||||
return {"mean": None, "var": None}
|
||||
else:
|
||||
return {"norm": make_dict("norm_enc"), "add_conv": make_dict("conv")}
|
||||
elif layer == "resblock":
|
||||
return {
|
||||
"norm1": make_dict(f"norm_{module}"),
|
||||
"norm2": make_dict(f"norm_{module}"),
|
||||
"conv1": make_dict("conv"),
|
||||
"conv2": make_dict("conv"),
|
||||
"conv_shortcut": make_dict("conv"),
|
||||
}
|
||||
elif layer.isdigit():
|
||||
out_dict = {"down": [make_dict("conv"), make_dict("conv")], "up": make_dict("conv")}
|
||||
for i in range(p):
|
||||
out_dict[i] = make_dict(f"resblock_{module}")
|
||||
return out_dict
|
||||
|
||||
cache = {
|
||||
"conv_in": make_dict("conv"),
|
||||
"mid_1": make_dict("resblock_dec"),
|
||||
"mid_2": make_dict("resblock_dec"),
|
||||
"norm_out": make_dict("norm_dec"),
|
||||
"conv_out": make_dict("conv"),
|
||||
}
|
||||
for i in range(len(self.config.ch_mult)):
|
||||
cache[i] = make_dict(f"{i}_dec", p=self.config.num_res_blocks + 1)
|
||||
return cache
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""Enable sliced VAE decoding."""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""Disable sliced VAE decoding."""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor, seg_len: int = 16) -> torch.Tensor:
|
||||
# Cached encoder processes by segments
|
||||
cache = self._make_encoder_cache()
|
||||
|
||||
split_list = [seg_len + 1]
|
||||
n_frames = x.size(2) - (seg_len + 1)
|
||||
while n_frames > 0:
|
||||
split_list.append(seg_len)
|
||||
n_frames -= seg_len
|
||||
split_list[-1] += n_frames
|
||||
|
||||
latent = []
|
||||
for chunk in torch.split(x, split_list, dim=2):
|
||||
l = self.encoder(chunk, cache)
|
||||
sample, _ = torch.chunk(l, 2, dim=1)
|
||||
latent.append(sample)
|
||||
|
||||
return torch.cat(latent, dim=2)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True
|
||||
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
||||
"""
|
||||
Encode a batch of videos into latents.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of videos with shape (B, C, T, H, W).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded videos.
|
||||
"""
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
|
||||
# For cached encoder, we already did the split in _encode
|
||||
h_double = torch.cat([h, torch.zeros_like(h)], dim=1)
|
||||
posterior = DiagonalGaussianDistribution(h_double)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor, seg_len: int = 16) -> torch.Tensor:
|
||||
cache = self._make_decoder_cache()
|
||||
temporal_compress = self.config.temporal_compress_times
|
||||
|
||||
split_list = [seg_len + 1]
|
||||
n_frames = temporal_compress * (z.size(2) - 1) - seg_len
|
||||
while n_frames > 0:
|
||||
split_list.append(seg_len)
|
||||
n_frames -= seg_len
|
||||
split_list[-1] += n_frames
|
||||
split_list = [math.ceil(size / temporal_compress) for size in split_list]
|
||||
|
||||
recs = []
|
||||
for chunk in torch.split(z, split_list, dim=2):
|
||||
out = self.decoder(chunk, cache)
|
||||
recs.append(out)
|
||||
|
||||
return torch.cat(recs, dim=2)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
"""
|
||||
Decode a batch of videos.
|
||||
|
||||
Args:
|
||||
z (`torch.Tensor`): Input batch of latent vectors with shape (B, C, T, H, W).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.vae.DecoderOutput`] or `tuple`: Decoded video.
|
||||
"""
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
x = sample
|
||||
posterior = self.encode(x).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z).sample
|
||||
if not return_dict:
|
||||
return (dec,)
|
||||
return DecoderOutput(sample=dec)
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import math
|
||||
from math import prod
|
||||
from typing import Any
|
||||
@@ -25,7 +24,7 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, deprecate, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -307,7 +306,7 @@ class QwenEmbedRope(nn.Module):
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
@lru_cache_unless_export(maxsize=128)
|
||||
def _compute_video_freqs(
|
||||
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
|
||||
) -> torch.Tensor:
|
||||
@@ -428,7 +427,7 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@lru_cache_unless_export(maxsize=None)
|
||||
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
|
||||
seq_lens = frame * height * width
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
@@ -450,7 +449,7 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@lru_cache_unless_export(maxsize=None)
|
||||
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
|
||||
seq_lens = frame * height * width
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
|
||||
@@ -324,17 +324,18 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
`inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
The sequence of generated hidden-states.
|
||||
"""
|
||||
cache_position_kwargs = {}
|
||||
if is_transformers_version("<", "4.52.1"):
|
||||
cache_position_kwargs["input_ids"] = inputs_embeds
|
||||
else:
|
||||
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
|
||||
cache_position_kwargs["device"] = (
|
||||
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
|
||||
)
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
|
||||
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
|
||||
if hasattr(self.language_model, "_get_initial_cache_position"):
|
||||
cache_position_kwargs = {}
|
||||
if is_transformers_version("<", "4.52.1"):
|
||||
cache_position_kwargs["input_ids"] = inputs_embeds
|
||||
else:
|
||||
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
|
||||
cache_position_kwargs["device"] = (
|
||||
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
|
||||
)
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
# prepare model inputs
|
||||
|
||||
@@ -521,6 +521,36 @@ class AutoencoderKLHunyuanVideo15(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLKVAE(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLKVAEVideo(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLLTX2Audio(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from numpy.linalg import norm
|
||||
from packaging import version
|
||||
|
||||
from .constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from .deprecation_utils import deprecate
|
||||
from .import_utils import (
|
||||
BACKENDS_MAPPING,
|
||||
is_accelerate_available,
|
||||
@@ -67,9 +68,11 @@ else:
|
||||
global_rng = random.Random()
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.warning(
|
||||
"diffusers.utils.testing_utils' is deprecated and will be removed in a future version. "
|
||||
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
|
||||
deprecate(
|
||||
"diffusers.utils.testing_utils",
|
||||
"1.0.0",
|
||||
"diffusers.utils.testing_utils is deprecated and will be removed in a future version. "
|
||||
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. ",
|
||||
)
|
||||
_required_peft_version = is_peft_available() and version.parse(
|
||||
version.parse(importlib.metadata.version("peft")).base_version
|
||||
|
||||
@@ -19,11 +19,16 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Callable, ParamSpec, TypeVar
|
||||
|
||||
from . import logging
|
||||
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.fft import fftn, fftshift, ifftn, ifftshift
|
||||
@@ -333,5 +338,23 @@ def disable_full_determinism():
|
||||
torch.use_deterministic_algorithms(False)
|
||||
|
||||
|
||||
@functools.wraps(functools.lru_cache)
|
||||
def lru_cache_unless_export(maxsize=128, typed=False):
|
||||
def outer_wrapper(fn: Callable[P, T]):
|
||||
cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
|
||||
if is_torch_version("<", "2.7.0"):
|
||||
return cached
|
||||
|
||||
@functools.wraps(fn)
|
||||
def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
if torch.compiler.is_exporting():
|
||||
return fn(*args, **kwargs)
|
||||
return cached(*args, **kwargs)
|
||||
|
||||
return inner_wrapper
|
||||
|
||||
return outer_wrapper
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
torch_device = get_device()
|
||||
|
||||
@@ -28,7 +28,6 @@ from diffusers.utils.import_utils import is_peft_available
|
||||
|
||||
from ..testing_utils import (
|
||||
floats_tensor,
|
||||
is_flaky,
|
||||
require_peft_backend,
|
||||
require_peft_version_greater,
|
||||
skip_mps,
|
||||
@@ -46,7 +45,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
@is_flaky(max_attempts=10, description="very flaky class")
|
||||
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = WanVACEPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
@@ -73,8 +71,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
"base_dim": 3,
|
||||
"z_dim": 4,
|
||||
"dim_mult": [1, 1, 1, 1],
|
||||
"latents_mean": torch.randn(4).numpy().tolist(),
|
||||
"latents_std": torch.randn(4).numpy().tolist(),
|
||||
"latents_mean": [-0.7571, -0.7089, -0.9113, -0.7245],
|
||||
"latents_std": [2.8184, 1.4541, 2.3275, 2.6558],
|
||||
"num_res_blocks": 1,
|
||||
"temperal_downsample": [False, True, True],
|
||||
}
|
||||
|
||||
73
tests/models/autoencoders/test_models_autoencoder_kl_kvae.py
Normal file
73
tests/models/autoencoders/test_models_autoencoder_kl_kvae.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLKVAE
|
||||
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLKVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLKVAE
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_kvae_config(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"channels": 32,
|
||||
"num_enc_blocks": 1,
|
||||
"num_dec_blocks": 1,
|
||||
"z_channels": 4,
|
||||
"double_z": True,
|
||||
"ch_mult": (1, 2),
|
||||
"sample_size": 32,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_kvae_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"KVAEEncoder2D",
|
||||
"KVAEDecoder2D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
@@ -0,0 +1,118 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLKVAEVideo
|
||||
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLKVAEVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLKVAEVideo
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_kvae_video_config(self):
|
||||
return {
|
||||
"ch": 32,
|
||||
"ch_mult": (1, 2),
|
||||
"num_res_blocks": 1,
|
||||
"in_channels": 3,
|
||||
"out_ch": 3,
|
||||
"z_channels": 4,
|
||||
"temporal_compress_times": 2,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 3 # satisfies (T-1) % temporal_compress_times == 0 with temporal_compress_times=2
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
|
||||
video = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": video}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 3, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 3, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_kvae_video_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"KVAECachedEncoder3D",
|
||||
"KVAECachedDecoder3D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
|
||||
)
|
||||
def test_model_parallelism(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
|
||||
)
|
||||
def test_sharded_checkpoints_device_map(self):
|
||||
pass
|
||||
|
||||
def _run_nondeterministic(self, fn):
|
||||
# reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
|
||||
# temporarily relax the requirement for training tests that do backward passes.
|
||||
import torch
|
||||
|
||||
torch.use_deterministic_algorithms(False)
|
||||
try:
|
||||
fn()
|
||||
finally:
|
||||
torch.use_deterministic_algorithms(True)
|
||||
|
||||
def test_training(self):
|
||||
self._run_nondeterministic(super().test_training)
|
||||
|
||||
def test_ema_training(self):
|
||||
self._run_nondeterministic(super().test_ema_training)
|
||||
|
||||
@unittest.skip(
|
||||
"Gradient checkpointing recomputes the forward pass, but the model uses a stateful cache_dict "
|
||||
"that is mutated during the first forward. On recomputation the cache is already populated, "
|
||||
"causing a different execution path and numerically different gradients. "
|
||||
"GC still reduces peak memory usage; gradient correctness in the presence of GC is a known limitation."
|
||||
)
|
||||
def test_effective_gradient_checkpointing(self):
|
||||
pass
|
||||
|
||||
def test_layerwise_casting_training(self):
|
||||
self._run_nondeterministic(super().test_layerwise_casting_training)
|
||||
@@ -481,6 +481,8 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
# ensure that enable_lora_hotswap is called before loading the first adapter
|
||||
import logging
|
||||
|
||||
from diffusers.utils import logging as diffusers_logging
|
||||
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
@@ -488,21 +490,31 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
msg = (
|
||||
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
|
||||
)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
assert any(msg in record.message for record in caplog.records)
|
||||
diffusers_logging.enable_propagation()
|
||||
try:
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
|
||||
assert any(msg in record.message for record in caplog.records)
|
||||
finally:
|
||||
diffusers_logging.disable_propagation()
|
||||
|
||||
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
|
||||
# check possibility to ignore the error/warning
|
||||
import logging
|
||||
|
||||
from diffusers.utils import logging as diffusers_logging
|
||||
|
||||
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.add_adapter(lora_config)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
|
||||
assert len(caplog.records) == 0
|
||||
diffusers_logging.enable_propagation()
|
||||
try:
|
||||
with caplog.at_level(logging.WARNING):
|
||||
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
|
||||
assert len(caplog.records) == 0
|
||||
finally:
|
||||
diffusers_logging.disable_propagation()
|
||||
|
||||
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
|
||||
# check that wrong argument value raises an error
|
||||
@@ -518,20 +530,26 @@ class LoraHotSwappingForModelTesterMixin:
|
||||
# check the error and log
|
||||
import logging
|
||||
|
||||
from diffusers.utils import logging as diffusers_logging
|
||||
|
||||
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
|
||||
target_modules0 = ["to_q"]
|
||||
target_modules1 = ["to_q", "to_k"]
|
||||
with pytest.raises(RuntimeError): # peft raises RuntimeError
|
||||
with caplog.at_level(logging.ERROR):
|
||||
self._check_model_hotswap(
|
||||
tmp_path,
|
||||
do_compile=True,
|
||||
rank0=8,
|
||||
rank1=8,
|
||||
target_modules0=target_modules0,
|
||||
target_modules1=target_modules1,
|
||||
)
|
||||
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
|
||||
diffusers_logging.enable_propagation()
|
||||
try:
|
||||
with pytest.raises(RuntimeError): # peft raises RuntimeError
|
||||
with caplog.at_level(logging.ERROR):
|
||||
self._check_model_hotswap(
|
||||
tmp_path,
|
||||
do_compile=True,
|
||||
rank0=8,
|
||||
rank1=8,
|
||||
target_modules0=target_modules0,
|
||||
target_modules1=target_modules1,
|
||||
)
|
||||
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
|
||||
finally:
|
||||
diffusers_logging.disable_propagation()
|
||||
|
||||
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
|
||||
@require_torch_version_greater("2.7.1")
|
||||
|
||||
@@ -22,6 +22,7 @@ import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from diffusers.models._modeling_parallel import ContextParallelConfig
|
||||
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
|
||||
|
||||
from ...testing_utils import (
|
||||
is_context_parallel,
|
||||
@@ -160,16 +161,21 @@ def _custom_mesh_worker(
|
||||
@require_torch_multi_accelerator
|
||||
class ContextParallelTesterMixin:
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_inference(self, cp_type):
|
||||
def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
|
||||
if not torch.distributed.is_available():
|
||||
pytest.skip("torch.distributed is not available.")
|
||||
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
if cp_type == "ring_degree":
|
||||
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
|
||||
if active_backend == AttentionBackendName.NATIVE:
|
||||
pytest.skip("Ring attention is not supported with the native attention backend.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
|
||||
|
||||
# Move all tensors to CPU for multiprocessing
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
@@ -194,6 +200,11 @@ class ContextParallelTesterMixin:
|
||||
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
@pytest.mark.xfail(reason="Context parallel may not support batch_size > 1")
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
|
||||
def test_context_parallel_batch_inputs(self, cp_type):
|
||||
self.test_context_parallel_inference(cp_type, batch_size=2)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cp_type,mesh_shape,mesh_dim_names",
|
||||
[
|
||||
@@ -209,6 +220,11 @@ class ContextParallelTesterMixin:
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
if cp_type == "ring_degree":
|
||||
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
|
||||
if active_backend == AttentionBackendName.NATIVE:
|
||||
pytest.skip("Ring attention is not supported with the native attention backend.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
|
||||
|
||||
@@ -41,7 +41,6 @@ from ..testing_utils import (
|
||||
ModelOptCompileTesterMixin,
|
||||
ModelOptTesterMixin,
|
||||
ModelTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
QuantoCompileTesterMixin,
|
||||
QuantoTesterMixin,
|
||||
SingleFileTesterMixin,
|
||||
@@ -151,8 +150,7 @@ class FluxTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"axes_dims_rope": [4, 4, 8],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
height = width = 4
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
@@ -219,6 +217,10 @@ class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
|
||||
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Flux Transformer."""
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"FluxTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Flux Transformer."""
|
||||
@@ -412,10 +414,6 @@ class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAn
|
||||
"""BitsAndBytes + compile tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
|
||||
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
|
||||
"""FirstBlockCache tests for Flux Transformer."""
|
||||
|
||||
|
||||
@@ -13,48 +13,94 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import Flux2Transformer2DModel, attention_backend
|
||||
from diffusers import Flux2Transformer2DModel
|
||||
from diffusers.models.transformers.transformer_flux2 import (
|
||||
Flux2KVAttnProcessor,
|
||||
Flux2KVCache,
|
||||
Flux2KVLayerCache,
|
||||
Flux2KVParallelSelfAttnProcessor,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
ContextParallelTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoCompileTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = Flux2Transformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
class Flux2TransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return Flux2Transformer2DModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
return (16, 4)
|
||||
|
||||
def prepare_dummy_input(self, height=4, width=4):
|
||||
batch_size = 1
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
return [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def uses_custom_attn_processor(self) -> bool:
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
return True
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||
return {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
"num_single_layers": 1,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 2,
|
||||
"joint_attention_dim": 32,
|
||||
"timestep_guidance_channels": 256, # Hardcoded in original code
|
||||
"axes_dims_rope": [4, 4, 4, 4],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
@@ -82,8 +128,286 @@ class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
|
||||
class TestFlux2Transformer(Flux2TransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestFlux2TransformerMemory(Flux2TransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerTraining(Flux2TransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Flux2 Transformer."""
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Flux2Transformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestFlux2TransformerAttention(Flux2TransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerContextParallel(Flux2TransformerTesterConfig, ContextParallelTesterMixin):
|
||||
"""Context Parallel inference tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerLoRA(Flux2TransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerLoRAHotSwap(Flux2TransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for Flux2 Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for LoRA hotswap tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class TestFlux2TransformerCompile(Flux2TransformerTesterConfig, TorchCompileTesterMixin):
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for compilation tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class TestFlux2TransformerBitsAndBytes(Flux2TransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerTorchAo(Flux2TransformerTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerGGUF(Flux2TransformerTesterConfig, GGUFTesterMixin):
|
||||
"""GGUF quantization tests for Flux2 Transformer."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real FLUX2 model dimensions.
|
||||
|
||||
Flux2 defaults: in_channels=128, joint_attention_dim=15360
|
||||
"""
|
||||
batch_size = 1
|
||||
height = 64
|
||||
width = 64
|
||||
sequence_length = 512
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
)
|
||||
|
||||
# Flux2 uses 4D image/text IDs (t, h, w, l)
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
|
||||
guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class TestFlux2TransformerTorchAoCompile(Flux2TransformerTesterConfig, TorchAoCompileTesterMixin):
|
||||
"""TorchAO + compile tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerGGUFCompile(Flux2TransformerTesterConfig, GGUFCompileTesterMixin):
|
||||
"""GGUF + compile tests for Flux2 Transformer."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real FLUX2 model dimensions.
|
||||
|
||||
Flux2 defaults: in_channels=128, joint_attention_dim=15360
|
||||
"""
|
||||
batch_size = 1
|
||||
height = 64
|
||||
width = 64
|
||||
sequence_length = 512
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
)
|
||||
|
||||
# Flux2 uses 4D image/text IDs (t, h, w, l)
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
|
||||
guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class Flux2TransformerKVCacheTesterConfig(BaseModelTesterConfig):
|
||||
num_ref_tokens = 4
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return Flux2Transformer2DModel
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def uses_custom_attn_processor(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||
return {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
@@ -91,72 +415,210 @@ class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 2,
|
||||
"joint_attention_dim": 32,
|
||||
"timestep_guidance_channels": 256, # Hardcoded in original code
|
||||
"timestep_guidance_channels": 256,
|
||||
"axes_dims_rope": [4, 4, 4, 4],
|
||||
}
|
||||
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
num_ref_tokens = self.num_ref_tokens
|
||||
|
||||
# TODO (Daniel, Sayak): We can remove this test.
|
||||
def test_flux2_consistency(self, seed=0):
|
||||
torch.manual_seed(seed)
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ref_hidden_states = randn_tensor(
|
||||
(batch_size, num_ref_tokens, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
img_hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
hidden_states = torch.cat([ref_hidden_states, img_hidden_states], dim=1)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
model = self.model_class(**init_dict)
|
||||
# state_dict = model.state_dict()
|
||||
# for key, param in state_dict.items():
|
||||
# print(f"{key} | {param.shape}")
|
||||
# torch.save(state_dict, "/raid/daniel_gu/test_flux2_params/diffusers.pt")
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
|
||||
ref_t_coords = torch.arange(1)
|
||||
ref_h_coords = torch.arange(num_ref_tokens)
|
||||
ref_w_coords = torch.arange(1)
|
||||
ref_l_coords = torch.arange(1)
|
||||
ref_ids = torch.cartesian_prod(ref_t_coords, ref_h_coords, ref_w_coords, ref_l_coords)
|
||||
ref_ids = ref_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
image_ids = torch.cat([ref_ids, image_ids], dim=1)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class TestFlux2TransformerKVCache(Flux2TransformerKVCacheTesterConfig):
|
||||
"""KV cache tests for Flux2 Transformer."""
|
||||
|
||||
def test_kv_layer_cache_store_and_get(self):
|
||||
cache = Flux2KVLayerCache()
|
||||
k = torch.randn(1, 4, 2, 16)
|
||||
v = torch.randn(1, 4, 2, 16)
|
||||
cache.store(k, v)
|
||||
k_out, v_out = cache.get()
|
||||
assert torch.equal(k, k_out)
|
||||
assert torch.equal(v, v_out)
|
||||
|
||||
def test_kv_layer_cache_get_before_store_raises(self):
|
||||
cache = Flux2KVLayerCache()
|
||||
try:
|
||||
cache.get()
|
||||
assert False, "Expected RuntimeError"
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def test_kv_layer_cache_clear(self):
|
||||
cache = Flux2KVLayerCache()
|
||||
cache.store(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16))
|
||||
cache.clear()
|
||||
assert cache.k_ref is None
|
||||
assert cache.v_ref is None
|
||||
|
||||
def test_kv_cache_structure(self):
|
||||
num_double = 3
|
||||
num_single = 2
|
||||
cache = Flux2KVCache(num_double, num_single)
|
||||
assert len(cache.double_block_caches) == num_double
|
||||
assert len(cache.single_block_caches) == num_single
|
||||
assert cache.num_ref_tokens == 0
|
||||
|
||||
for i in range(num_double):
|
||||
assert isinstance(cache.get_double(i), Flux2KVLayerCache)
|
||||
for i in range(num_single):
|
||||
assert isinstance(cache.get_single(i), Flux2KVLayerCache)
|
||||
|
||||
def test_kv_cache_clear(self):
|
||||
cache = Flux2KVCache(2, 1)
|
||||
cache.num_ref_tokens = 4
|
||||
cache.get_double(0).store(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16))
|
||||
cache.clear()
|
||||
assert cache.num_ref_tokens == 0
|
||||
assert cache.get_double(0).k_ref is None
|
||||
|
||||
def _set_kv_attn_processors(self, model):
|
||||
for block in model.transformer_blocks:
|
||||
block.attn.set_processor(Flux2KVAttnProcessor())
|
||||
for block in model.single_transformer_blocks:
|
||||
block.attn.set_processor(Flux2KVParallelSelfAttnProcessor())
|
||||
|
||||
@torch.no_grad()
|
||||
def test_extract_mode_returns_cache(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
self._set_kv_attn_processors(model)
|
||||
|
||||
output = model(
|
||||
**self.get_dummy_inputs(),
|
||||
kv_cache_mode="extract",
|
||||
num_ref_tokens=self.num_ref_tokens,
|
||||
ref_fixed_timestep=0.0,
|
||||
)
|
||||
|
||||
assert output.kv_cache is not None
|
||||
assert isinstance(output.kv_cache, Flux2KVCache)
|
||||
assert output.kv_cache.num_ref_tokens == self.num_ref_tokens
|
||||
|
||||
for layer_cache in output.kv_cache.double_block_caches:
|
||||
assert layer_cache.k_ref is not None
|
||||
assert layer_cache.v_ref is not None
|
||||
|
||||
for layer_cache in output.kv_cache.single_block_caches:
|
||||
assert layer_cache.k_ref is not None
|
||||
assert layer_cache.v_ref is not None
|
||||
|
||||
@torch.no_grad()
|
||||
def test_extract_mode_output_shape(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with attention_backend("native"):
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
height, width = 4, 4
|
||||
output = model(
|
||||
**self.get_dummy_inputs(height=height, width=width),
|
||||
kv_cache_mode="extract",
|
||||
num_ref_tokens=self.num_ref_tokens,
|
||||
ref_fixed_timestep=0.0,
|
||||
)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
assert output.sample.shape == (1, height * width, 4)
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
@torch.no_grad()
|
||||
def test_cached_mode_uses_cache(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
# input & output have to have the same shape
|
||||
input_tensor = inputs_dict[self.main_input_name]
|
||||
expected_shape = input_tensor.shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
height, width = 4, 4
|
||||
extract_output = model(
|
||||
**self.get_dummy_inputs(height=height, width=width),
|
||||
kv_cache_mode="extract",
|
||||
num_ref_tokens=self.num_ref_tokens,
|
||||
ref_fixed_timestep=0.0,
|
||||
)
|
||||
|
||||
# Check against expected slice
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([-0.3662, 0.4844, 0.6334, -0.3497, 0.2162, 0.0188, 0.0521, -0.2061, -0.2041, -0.0342, -0.7107, 0.4797, -0.3280, 0.7059, -0.0849, 0.4416])
|
||||
# fmt: on
|
||||
base_config = Flux2TransformerTesterConfig()
|
||||
cached_inputs = base_config.get_dummy_inputs(height=height, width=width)
|
||||
cached_output = model(
|
||||
**cached_inputs,
|
||||
kv_cache=extract_output.kv_cache,
|
||||
kv_cache_mode="cached",
|
||||
)
|
||||
|
||||
flat_output = output.cpu().flatten()
|
||||
generated_slice = torch.cat([flat_output[:8], flat_output[-8:]])
|
||||
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-4))
|
||||
assert cached_output.sample.shape == (1, height * width, 4)
|
||||
assert cached_output.kv_cache is None
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Flux2Transformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
@torch.no_grad()
|
||||
def test_extract_return_dict_false(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
output = model(
|
||||
**self.get_dummy_inputs(),
|
||||
kv_cache_mode="extract",
|
||||
num_ref_tokens=self.num_ref_tokens,
|
||||
ref_fixed_timestep=0.0,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = Flux2Transformer2DModel
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
assert isinstance(output, tuple)
|
||||
assert len(output) == 2
|
||||
assert isinstance(output[1], Flux2KVCache)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
@torch.no_grad()
|
||||
def test_no_kv_cache_mode_returns_no_cache(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
base_config = Flux2TransformerTesterConfig()
|
||||
output = model(**base_config.get_dummy_inputs())
|
||||
|
||||
|
||||
class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
||||
model_class = Flux2Transformer2DModel
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
assert output.kv_cache is None
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
@@ -77,8 +78,7 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_latent_channels = embedding_dim = 16
|
||||
height = width = 4
|
||||
sequence_length = 8
|
||||
@@ -106,9 +106,10 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
|
||||
|
||||
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_infers_text_seq_len_from_mask(self):
|
||||
@pytest.mark.parametrize("batch_size", [1, 2])
|
||||
def test_infers_text_seq_len_from_mask(self, batch_size):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.get_dummy_inputs(batch_size=batch_size)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
|
||||
@@ -122,7 +123,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
assert isinstance(per_sample_len, torch.Tensor)
|
||||
assert int(per_sample_len.max().item()) == 2
|
||||
assert normalized_mask.dtype == torch.bool
|
||||
assert normalized_mask.sum().item() == 2
|
||||
assert normalized_mask.sum().item() == 2 * batch_size
|
||||
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
|
||||
|
||||
inputs["encoder_hidden_states_mask"] = normalized_mask
|
||||
@@ -139,7 +140,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
)
|
||||
|
||||
assert int(per_sample_len2.max().item()) == 8
|
||||
assert normalized_mask2.sum().item() == 5
|
||||
assert normalized_mask2.sum().item() == 5 * batch_size
|
||||
|
||||
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], None
|
||||
@@ -149,9 +150,10 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
assert per_sample_len_none is None
|
||||
assert normalized_mask_none is None
|
||||
|
||||
def test_non_contiguous_attention_mask(self):
|
||||
@pytest.mark.parametrize("batch_size", [1, 2])
|
||||
def test_non_contiguous_attention_mask(self, batch_size):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.get_dummy_inputs(batch_size=batch_size)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
|
||||
@@ -284,6 +286,14 @@ class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterM
|
||||
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for QwenImage Transformer."""
|
||||
|
||||
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
|
||||
def test_hotswapping_compiled_model_linear(self):
|
||||
super().test_hotswapping_compiled_model_linear()
|
||||
|
||||
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
|
||||
def test_hotswapping_compiled_model_both_linear_and_other(self):
|
||||
super().test_hotswapping_compiled_model_both_linear_and_other()
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
@@ -32,6 +33,33 @@ from ..testing_utils import (
|
||||
)
|
||||
|
||||
|
||||
def _get_specified_components(path_or_repo_id, cache_dir=None):
|
||||
if os.path.isdir(path_or_repo_id):
|
||||
config_path = os.path.join(path_or_repo_id, "modular_model_index.json")
|
||||
else:
|
||||
try:
|
||||
config_path = hf_hub_download(
|
||||
repo_id=path_or_repo_id,
|
||||
filename="modular_model_index.json",
|
||||
local_dir=cache_dir,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
components = set()
|
||||
for k, v in config.items():
|
||||
if isinstance(v, (str, int, float, bool)):
|
||||
continue
|
||||
for entry in v:
|
||||
if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")):
|
||||
components.add(k)
|
||||
break
|
||||
return components
|
||||
|
||||
|
||||
class ModularPipelineTesterMixin:
|
||||
"""
|
||||
It provides a set of common tests for each modular pipeline,
|
||||
@@ -360,6 +388,39 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
def test_load_expected_components_from_pretrained(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path)
|
||||
if not expected:
|
||||
pytest.skip("Skipping test as we couldn't fetch the expected components.")
|
||||
|
||||
actual = {
|
||||
name
|
||||
for name in pipe.components
|
||||
if getattr(pipe, name, None) is not None
|
||||
and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null")
|
||||
}
|
||||
assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}"
|
||||
|
||||
def test_load_expected_components_from_save_pretrained(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
save_dir = str(tmp_path / "saved-pipeline")
|
||||
pipe.save_pretrained(save_dir)
|
||||
|
||||
expected = _get_specified_components(save_dir)
|
||||
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
|
||||
loaded_pipe.load_components(torch_dtype=torch.float32)
|
||||
|
||||
actual = {
|
||||
name
|
||||
for name in loaded_pipe.components
|
||||
if getattr(loaded_pipe, name, None) is not None
|
||||
and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null")
|
||||
}
|
||||
assert expected == actual, (
|
||||
f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}"
|
||||
)
|
||||
|
||||
def test_modular_index_consistency(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
components_spec = pipe._component_specs
|
||||
|
||||
@@ -13,8 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -182,6 +184,25 @@ class DeprecateTester(unittest.TestCase):
|
||||
assert str(warning.warning) == "This message is better!!!"
|
||||
assert "diffusers/tests/others/test_utils.py" in warning.filename
|
||||
|
||||
def test_deprecate_testing_utils_module(self):
|
||||
import diffusers.utils.testing_utils
|
||||
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
warnings.simplefilter("always")
|
||||
importlib.reload(diffusers.utils.testing_utils)
|
||||
|
||||
deprecation_warnings = [w for w in caught_warnings if issubclass(w.category, FutureWarning)]
|
||||
assert len(deprecation_warnings) >= 1, "Expected at least one FutureWarning from diffusers.utils.testing_utils"
|
||||
|
||||
messages = [str(w.message) for w in deprecation_warnings]
|
||||
assert any("diffusers.utils.testing_utils" in msg for msg in messages), (
|
||||
f"Expected a deprecation warning mentioning 'diffusers.utils.testing_utils', got: {messages}"
|
||||
)
|
||||
assert any(
|
||||
"diffusers.utils.testing_utils is deprecated and will be removed in a future version." in msg
|
||||
for msg in messages
|
||||
), f"Expected deprecation message substring not found, got: {messages}"
|
||||
|
||||
|
||||
# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
|
||||
class ExpectationsTester(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user