Mirror of https://github.com/huggingface/diffusers.git
Synced 2026-01-15 08:05:38 +08:00

Compare commits: modular-cu ... yiyi-kandi

45 Commits:
acabbc0033, b0e1b866fd, 3ffdf7f113, c8be08149e, 894aa98a27, 31a1474378, 600e9d6b87, 56b90b10ef, 78a23b9dde,
b9a3be2a15, 4450265bf7, cd3cc6156e, e7b91ed787, 28458d0caf, 8fd22c0c5d, 9b06afba6b, 327ab84d19, 588c12ab98,
6a0233eb7a, b615d5cb13, d5dcd94500, 7084106eaa, d62dffcb21, 0190e55641, f52f3b45b7, 88a8eea096, 235f0d5df8,
4aa22f3abe, 04efb19b1a, 7af80e9ffc, e3a3e9d1b6, 149fd53df8, f35c279439, 43bd1e81d2, 45240a7317, 07e11b270f,
70fa62baea, 22e14bdac8, 723d149dc1, 86b6c2b686, c8f3a36fba, 0bd738f52b, a0cf07f7e0, 7db6093c53, d53f848720
@@ -260,6 +260,7 @@ else:
"VQModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"Kandinsky5Transformer3DModel",
"attention_backend",
]
)
@@ -622,6 +623,7 @@ else:
"WanPipeline",
"WanVACEPipeline",
"WanVideoToVideoPipeline",
"Kandinsky5T2VPipeline",
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
@@ -951,6 +953,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
Kandinsky5Transformer3DModel,
attention_backend,
)
from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
@@ -1283,6 +1286,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WanPipeline,
WanVACEPipeline,
WanVideoToVideoPipeline,
Kandinsky5T2VPipeline,
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,

@@ -77,6 +77,7 @@ if is_torch_available():
"SanaLoraLoaderMixin",
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
"KandinskyLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
"QwenImageLoraLoaderMixin",
@@ -126,6 +127,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
WanLoraLoaderMixin,
KandinskyLoraLoaderMixin,
)
from .single_file import FromSingleFileMixin
from .textual_inversion import TextualInversionLoaderMixin

@@ -3638,6 +3638,292 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
"""
super().unfuse_lora(components=components, **kwargs)


class KandinskyLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`Kandinsky5Transformer3DModel`].
"""

_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME

@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return the state dict for the LoRA weights and the network alphas.

Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* of a pretrained model hosted on the Hub.
- A path to a *directory* containing the model weights.
- A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
use_safetensors (`bool`, *optional*):
Whether to use safetensors for loading.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata.
"""
# Load the main state dict first which has the LoRA layers
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)

allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True

user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)

is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible with Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

out = (state_dict, metadata) if return_lora_metadata else state_dict
return out

def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`.

Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model.
hotswap (`bool`, *optional*):
Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights.
kwargs (`dict`, *optional*):
See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")

low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)

# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")

# Load LoRA into transformer
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)

@classmethod
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
Load the LoRA layers specified in `state_dict` into `transformer`.

Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters.
transformer (`Kandinsky5Transformer3DModel`):
The transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights.
hotswap (`bool`, *optional*):
See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)

# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)


@classmethod
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata=None,
):
r"""
Save the LoRA parameters corresponding to the transformer.

Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process.
save_function (`Callable`):
The function to use to save the state dictionary.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer.
"""
lora_layers = {}
lora_metadata = {}

if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

if not lora_layers:
raise ValueError(
"You must pass `transformer_lora_layers`."
)

cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)

def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.

Args:
components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing.

Example:
```py
from diffusers import Kandinsky5T2VPipeline

pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
pipeline.load_lora_weights("path/to/lora.safetensors")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)

def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of [`pipe.fuse_lora()`].

Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
"""
super().unfuse_lora(components=components, **kwargs)


class WanLoraLoaderMixin(LoraBaseMixin):
r"""
@@ -4802,4 +5088,4 @@ class LoraLoaderMixin(StableDiffusionLoraLoaderMixin):
def __init__(self, *args, **kwargs):
deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead."
deprecate("LoraLoaderMixin", "1.0.0", deprecation_message)
super().__init__(*args, **kwargs)

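The new `KandinskyLoraLoaderMixin` above exposes the usual diffusers LoRA entry points (`lora_state_dict`, `load_lora_weights`, `save_lora_weights`, `fuse_lora`, `unfuse_lora`) for the Kandinsky 5.0 transformer. A minimal usage sketch, assuming this branch is installed; the local LoRA file path and adapter name are placeholders:

```py
import torch
from diffusers import Kandinsky5T2VPipeline

# Model id taken from the pipeline docstring further down in this diff.
pipe = Kandinsky5T2VPipeline.from_pretrained(
    "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# Load LoRA layers into the transformer, then optionally fuse them for faster inference.
pipe.load_lora_weights("path/to/lora.safetensors", adapter_name="my_style")  # placeholder path/name
pipe.fuse_lora(lora_scale=0.7)

# ... run inference ...

# Undo the fusion if the original transformer weights are needed again.
pipe.unfuse_lora()
```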
@@ -101,6 +101,7 @@ if is_torch_available():
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -200,6 +201,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
TransformerTemporalModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
Kandinsky5Transformer3DModel,
)
from .unets import (
I2VGenXLUNet,

@@ -37,3 +37,4 @@ if is_torch_available():
from .transformer_temporal import TransformerTemporalModel
from .transformer_wan import WanTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel
from .transformer_kandinsky import Kandinsky5Transformer3DModel
src/diffusers/models/transformers/transformer_kandinsky.py (new file, 709 lines)
@@ -0,0 +1,709 @@
|
||||
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import BoolTensor, IntTensor, Tensor
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionMixin, FeedForward, AttentionModuleMixin
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
TimestepEmbedding,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import FP32LayerNorm
|
||||
from ..attention_dispatch import dispatch_attention_fn, _CAN_USE_FLEX_ATTN
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def exist(item):
|
||||
return item is not None
|
||||
|
||||
|
||||
def freeze(model):
|
||||
for p in model.parameters():
|
||||
p.requires_grad = False
|
||||
return model
|
||||
|
||||
|
||||
@torch.autocast(device_type="cuda", enabled=False)
|
||||
def get_freqs(dim, max_period=10000.0):
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=dim, dtype=torch.float32)
|
||||
/ dim
|
||||
)
|
||||
return freqs
|
||||
|
||||
|
||||
def fractal_flatten(x, rope, shape, block_mask=False):
|
||||
if block_mask:
|
||||
pixel_size = 8
|
||||
x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1)
|
||||
rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1)
|
||||
x = x.flatten(1, 2)
|
||||
rope = rope.flatten(1, 2)
|
||||
else:
|
||||
x = x.flatten(1, 3)
|
||||
rope = rope.flatten(1, 3)
|
||||
return x, rope
|
||||
|
||||
|
||||
def fractal_unflatten(x, shape, block_mask=False):
|
||||
if block_mask:
|
||||
pixel_size = 8
|
||||
x = x.reshape(x.shape[0], -1, pixel_size**2, *x.shape[2:])
|
||||
x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1)
|
||||
else:
|
||||
x = x.reshape(*shape, *x.shape[2:])
|
||||
return x
|
||||
|
||||
|
||||
def local_patching(x, shape, group_size, dim=0):
|
||||
batch_size, duration, height, width = shape
|
||||
g1, g2, g3 = group_size
|
||||
x = x.reshape(
|
||||
*x.shape[:dim],
|
||||
duration // g1,
|
||||
g1,
|
||||
height // g2,
|
||||
g2,
|
||||
width // g3,
|
||||
g3,
|
||||
*x.shape[dim + 3 :],
|
||||
)
|
||||
x = x.permute(
|
||||
*range(len(x.shape[:dim])),
|
||||
dim,
|
||||
dim + 2,
|
||||
dim + 4,
|
||||
dim + 1,
|
||||
dim + 3,
|
||||
dim + 5,
|
||||
*range(dim + 6, len(x.shape)),
|
||||
)
|
||||
x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)
|
||||
return x
|
||||
|
||||
|
||||
def local_merge(x, shape, group_size, dim=0):
|
||||
batch_size, duration, height, width = shape
|
||||
g1, g2, g3 = group_size
|
||||
x = x.reshape(
|
||||
*x.shape[:dim],
|
||||
duration // g1,
|
||||
height // g2,
|
||||
width // g3,
|
||||
g1,
|
||||
g2,
|
||||
g3,
|
||||
*x.shape[dim + 2 :],
|
||||
)
|
||||
x = x.permute(
|
||||
*range(len(x.shape[:dim])),
|
||||
dim,
|
||||
dim + 3,
|
||||
dim + 1,
|
||||
dim + 4,
|
||||
dim + 2,
|
||||
dim + 5,
|
||||
*range(dim + 6, len(x.shape)),
|
||||
)
|
||||
x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3)
|
||||
return x
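As a sanity check, `local_patching` and `local_merge` are exact inverses of each other: the first groups a (batch, T, H, W, C) tensor into pixel blocks, the second restores the original layout. A small sketch, assuming this branch of diffusers is installed so the helpers are importable from the new module:

```py
import torch
from diffusers.models.transformers.transformer_kandinsky import local_merge, local_patching

x = torch.randn(1, 4, 16, 16, 8)                      # (batch, T, H, W, channels)
shape = (1, 4, 16, 16)                                 # (batch, duration, height, width)

patched = local_patching(x, shape, (1, 8, 8), dim=1)   # group into 1x8x8 pixel blocks
print(patched.shape)                                   # torch.Size([1, 16, 64, 8]) -> 16 blocks of 64 pixels

restored = local_merge(patched, shape, (1, 8, 8), dim=1)
print(torch.equal(restored, x))                         # True: pure reshape/permute round trip
```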
|
||||
|
||||
|
||||
def nablaT_v2(
|
||||
q: Tensor,
|
||||
k: Tensor,
|
||||
sta: Tensor,
|
||||
thr: float = 0.9,
|
||||
):
|
||||
if _CAN_USE_FLEX_ATTN:
|
||||
from torch.nn.attention.flex_attention import BlockMask
|
||||
else:
|
||||
raise ValueError("Nabla attention is not supported with this version of PyTorch")
|
||||
|
||||
q = q.transpose(1, 2).contiguous()
|
||||
k = k.transpose(1, 2).contiguous()
|
||||
|
||||
# Map estimation
|
||||
B, h, S, D = q.shape
|
||||
s1 = S // 64
|
||||
qa = q.reshape(B, h, s1, 64, D).mean(-2)
|
||||
ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1)
|
||||
map = qa @ ka
|
||||
|
||||
map = torch.softmax(map / math.sqrt(D), dim=-1)
|
||||
# Map binarization
|
||||
vals, inds = map.sort(-1)
|
||||
cvals = vals.cumsum_(-1)
|
||||
mask = (cvals >= 1 - thr).int()
|
||||
mask = mask.gather(-1, inds.argsort(-1))
|
||||
|
||||
mask = torch.logical_or(mask, sta)
|
||||
|
||||
# BlockMask creation
|
||||
kv_nb = mask.sum(-1).to(torch.int32)
|
||||
kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
|
||||
return BlockMask.from_kv_blocks(
|
||||
torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None
|
||||
)
|
||||
|
||||
|
||||
class Kandinsky5TimeEmbeddings(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, max_period=10000.0):
|
||||
super().__init__()
|
||||
assert model_dim % 2 == 0
|
||||
self.model_dim = model_dim
|
||||
self.max_period = max_period
|
||||
self.register_buffer(
|
||||
"freqs", get_freqs(model_dim // 2, max_period), persistent=False
|
||||
)
|
||||
self.freqs = get_freqs(self.model_dim // 2, self.max_period)
|
||||
self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
|
||||
|
||||
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
||||
def forward(self, time):
|
||||
args = torch.outer(time, self.freqs.to(device=time.device))
|
||||
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
|
||||
return time_embed
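`Kandinsky5TimeEmbeddings` is a standard sinusoidal timestep embedding followed by a two-layer MLP; it maps a batch of scalar timesteps to vectors of size `time_dim`. A small shape sketch under the same import-path assumption as above (on a CPU-only machine the `cuda` autocast guard is simply disabled with a warning):

```py
import torch
from diffusers.models.transformers.transformer_kandinsky import Kandinsky5TimeEmbeddings

time_embed = Kandinsky5TimeEmbeddings(model_dim=64, time_dim=128)
out = time_embed(torch.tensor([0.0, 500.0, 999.0]))  # one embedding per timestep
print(out.shape)                                      # torch.Size([3, 128])
```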
|
||||
|
||||
|
||||
class Kandinsky5TextEmbeddings(nn.Module):
|
||||
def __init__(self, text_dim, model_dim):
|
||||
super().__init__()
|
||||
self.in_layer = nn.Linear(text_dim, model_dim, bias=True)
|
||||
self.norm = nn.LayerNorm(model_dim, elementwise_affine=True)
|
||||
|
||||
def forward(self, text_embed):
|
||||
text_embed = self.in_layer(text_embed)
|
||||
return self.norm(text_embed).type_as(text_embed)
|
||||
|
||||
|
||||
class Kandinsky5VisualEmbeddings(nn.Module):
|
||||
def __init__(self, visual_dim, model_dim, patch_size):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size, duration, height, width, dim = x.shape
|
||||
x = (
|
||||
x.view(
|
||||
batch_size,
|
||||
duration // self.patch_size[0],
|
||||
self.patch_size[0],
|
||||
height // self.patch_size[1],
|
||||
self.patch_size[1],
|
||||
width // self.patch_size[2],
|
||||
self.patch_size[2],
|
||||
dim,
|
||||
)
|
||||
.permute(0, 1, 3, 5, 2, 4, 6, 7)
|
||||
.flatten(4, 7)
|
||||
)
|
||||
return self.in_layer(x)
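`Kandinsky5VisualEmbeddings` folds each `patch_size` block of the latent video into a single token before the linear projection, so the output keeps the (batch, T', H', W', model_dim) layout. A shape sketch, again assuming the module path added by this diff:

```py
import torch
from diffusers.models.transformers.transformer_kandinsky import Kandinsky5VisualEmbeddings

embed = Kandinsky5VisualEmbeddings(visual_dim=4, model_dim=64, patch_size=(1, 2, 2))
latents = torch.randn(1, 4, 8, 8, 4)   # (batch, T, H, W, visual_dim)
tokens = embed(latents)
print(tokens.shape)                     # torch.Size([1, 4, 4, 4, 64]): one token per 1x2x2 patch
```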
|
||||
|
||||
|
||||
class Kandinsky5RoPE1D(nn.Module):
|
||||
def __init__(self, dim, max_pos=1024, max_period=10000.0):
|
||||
super().__init__()
|
||||
self.max_period = max_period
|
||||
self.dim = dim
|
||||
self.max_pos = max_pos
|
||||
freq = get_freqs(dim // 2, max_period)
|
||||
pos = torch.arange(max_pos, dtype=freq.dtype)
|
||||
self.register_buffer("args", torch.outer(pos, freq), persistent=False)
|
||||
|
||||
@torch.autocast(device_type="cuda", enabled=False)
|
||||
def forward(self, pos):
|
||||
args = self.args[pos]
|
||||
cosine = torch.cos(args)
|
||||
sine = torch.sin(args)
|
||||
rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
|
||||
rope = rope.view(*rope.shape[:-1], 2, 2)
|
||||
return rope.unsqueeze(-4)
|
||||
|
||||
|
||||
class Kandinsky5RoPE3D(nn.Module):
|
||||
def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0):
|
||||
super().__init__()
|
||||
self.axes_dims = axes_dims
|
||||
self.max_pos = max_pos
|
||||
self.max_period = max_period
|
||||
|
||||
for i, (axes_dim, ax_max_pos) in enumerate(zip(axes_dims, max_pos)):
|
||||
freq = get_freqs(axes_dim // 2, max_period)
|
||||
pos = torch.arange(ax_max_pos, dtype=freq.dtype)
|
||||
self.register_buffer(f"args_{i}", torch.outer(pos, freq), persistent=False)
|
||||
|
||||
@torch.autocast(device_type="cuda", enabled=False)
|
||||
def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)):
|
||||
batch_size, duration, height, width = shape
|
||||
args_t = self.args_0[pos[0]] / scale_factor[0]
|
||||
args_h = self.args_1[pos[1]] / scale_factor[1]
|
||||
args_w = self.args_2[pos[2]] / scale_factor[2]
|
||||
|
||||
args = torch.cat(
|
||||
[
|
||||
args_t.view(1, duration, 1, 1, -1).repeat(
|
||||
batch_size, 1, height, width, 1
|
||||
),
|
||||
args_h.view(1, 1, height, 1, -1).repeat(
|
||||
batch_size, duration, 1, width, 1
|
||||
),
|
||||
args_w.view(1, 1, 1, width, -1).repeat(
|
||||
batch_size, duration, height, 1, 1
|
||||
),
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
cosine = torch.cos(args)
|
||||
sine = torch.sin(args)
|
||||
rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
|
||||
rope = rope.view(*rope.shape[:-1], 2, 2)
|
||||
return rope.unsqueeze(-4)
|
||||
|
||||
|
||||
class Kandinsky5Modulation(nn.Module):
|
||||
def __init__(self, time_dim, model_dim, num_params):
|
||||
super().__init__()
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = nn.Linear(time_dim, num_params * model_dim)
|
||||
self.out_layer.weight.data.zero_()
|
||||
self.out_layer.bias.data.zero_()
|
||||
|
||||
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
||||
def forward(self, x):
|
||||
return self.out_layer(self.activation(x))
|
||||
|
||||
|
||||
class Kandinsky5AttnProcessor:
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
||||
|
||||
|
||||
def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None):
|
||||
# query, key, value = self.get_qkv(x)
|
||||
query = attn.to_query(hidden_states)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
key = attn.to_key(encoder_hidden_states)
|
||||
value = attn.to_value(encoder_hidden_states)
|
||||
|
||||
shape, cond_shape = query.shape[:-1], key.shape[:-1]
|
||||
query = query.reshape(*shape, attn.num_heads, -1)
|
||||
key = key.reshape(*cond_shape, attn.num_heads, -1)
|
||||
value = value.reshape(*cond_shape, attn.num_heads, -1)
|
||||
|
||||
else:
|
||||
key = attn.to_key(hidden_states)
|
||||
value = attn.to_value(hidden_states)
|
||||
|
||||
shape = query.shape[:-1]
|
||||
query = query.reshape(*shape, attn.num_heads, -1)
|
||||
key = key.reshape(*shape, attn.num_heads, -1)
|
||||
value = value.reshape(*shape, attn.num_heads, -1)
|
||||
|
||||
# query, key = self.norm_qk(query, key)
|
||||
query = attn.query_norm(query.float()).type_as(query)
|
||||
key = attn.key_norm(key.float()).type_as(key)
|
||||
|
||||
def apply_rotary(x, rope):
|
||||
x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
|
||||
x_out = (rope * x_).sum(dim=-1)
|
||||
return x_out.reshape(*x.shape).to(torch.bfloat16)
|
||||
|
||||
if rotary_emb is not None:
|
||||
query = apply_rotary(query, rotary_emb).type_as(query)
|
||||
key = apply_rotary(key, rotary_emb).type_as(key)
|
||||
|
||||
if sparse_params is not None:
|
||||
attn_mask = nablaT_v2(
|
||||
query,
|
||||
key,
|
||||
sparse_params["sta_mask"],
|
||||
thr=sparse_params["P"],
|
||||
)
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attn_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(-2, -1)
|
||||
|
||||
attn_out = attn.out_layer(hidden_states)
|
||||
return attn_out
|
||||
|
||||
|
||||
|
||||
class Kandinsky5Attention(nn.Module, AttentionModuleMixin):
|
||||
|
||||
_default_processor_cls = Kandinsky5AttnProcessor
|
||||
_available_processors = [
|
||||
Kandinsky5AttnProcessor,
|
||||
]
|
||||
def __init__(self, num_channels, head_dim, processor=None):
|
||||
super().__init__()
|
||||
assert num_channels % head_dim == 0
|
||||
self.num_heads = num_channels // head_dim
|
||||
|
||||
self.to_query = nn.Linear(num_channels, num_channels, bias=True)
|
||||
self.to_key = nn.Linear(num_channels, num_channels, bias=True)
|
||||
self.to_value = nn.Linear(num_channels, num_channels, bias=True)
|
||||
self.query_norm = nn.RMSNorm(head_dim)
|
||||
self.key_norm = nn.RMSNorm(head_dim)
|
||||
|
||||
self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
sparse_params: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
import inspect
|
||||
|
||||
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
|
||||
quiet_attn_parameters = {}
|
||||
unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(
|
||||
f"attention_processor_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
|
||||
)
|
||||
kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
|
||||
|
||||
return self.processor(self, hidden_states, encoder_hidden_states=encoder_hidden_states, sparse_params=sparse_params, rotary_emb=rotary_emb, **kwargs)
|
||||
|
||||
class Kandinsky5FeedForward(nn.Module):
|
||||
def __init__(self, dim, ff_dim):
|
||||
super().__init__()
|
||||
self.in_layer = nn.Linear(dim, ff_dim, bias=False)
|
||||
self.activation = nn.GELU()
|
||||
self.out_layer = nn.Linear(ff_dim, dim, bias=False)
|
||||
|
||||
def forward(self, x):
|
||||
return self.out_layer(self.activation(self.in_layer(x)))
|
||||
|
||||
|
||||
class Kandinsky5OutLayer(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, visual_dim, patch_size):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.modulation = Kandinsky5Modulation(time_dim, model_dim, 2)
|
||||
self.norm = nn.LayerNorm(model_dim, elementwise_affine=False)
|
||||
self.out_layer = nn.Linear(
|
||||
model_dim, math.prod(patch_size) * visual_dim, bias=True
|
||||
)
|
||||
|
||||
def forward(self, visual_embed, text_embed, time_embed):
|
||||
shift, scale = torch.chunk(
|
||||
self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
|
||||
)
|
||||
|
||||
visual_embed = (self.norm(visual_embed.float()) * (scale.float()[:, None, None] + 1.0) + shift.float()[:, None, None]).type_as(visual_embed)
|
||||
|
||||
x = self.out_layer(visual_embed)
|
||||
|
||||
batch_size, duration, height, width, _ = x.shape
|
||||
x = (
|
||||
x.view(
|
||||
batch_size,
|
||||
duration,
|
||||
height,
|
||||
width,
|
||||
-1,
|
||||
self.patch_size[0],
|
||||
self.patch_size[1],
|
||||
self.patch_size[2],
|
||||
)
|
||||
.permute(0, 1, 5, 2, 6, 3, 7, 4)
|
||||
.flatten(1, 2)
|
||||
.flatten(2, 3)
|
||||
.flatten(3, 4)
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
class Kandinsky5TransformerEncoderBlock(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, ff_dim, head_dim):
|
||||
super().__init__()
|
||||
self.text_modulation = Kandinsky5Modulation(time_dim, model_dim, 6)
|
||||
|
||||
self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
|
||||
self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())
|
||||
|
||||
self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
|
||||
self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
|
||||
|
||||
def forward(self, x, time_embed, rope):
|
||||
self_attn_params, ff_params = torch.chunk(
|
||||
self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1
|
||||
)
|
||||
shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
|
||||
out = (self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
|
||||
out = self.self_attention(out, rotary_emb=rope)
|
||||
x = (x.float() + gate.float() * out.float()).type_as(x)
|
||||
|
||||
shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
|
||||
out = (self.feed_forward_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
|
||||
out = self.feed_forward(out)
|
||||
x = (x.float() + gate.float() * out.float()).type_as(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Kandinsky5TransformerDecoderBlock(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, ff_dim, head_dim):
|
||||
super().__init__()
|
||||
self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9)
|
||||
|
||||
self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
|
||||
self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())
|
||||
|
||||
self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
|
||||
self.cross_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())
|
||||
|
||||
self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
|
||||
self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
|
||||
|
||||
def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params):
|
||||
self_attn_params, cross_attn_params, ff_params = torch.chunk(
|
||||
self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
|
||||
)
|
||||
|
||||
shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
|
||||
visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed)
|
||||
visual_out = self.self_attention(visual_out, rotary_emb=rope, sparse_params=sparse_params)
|
||||
visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
|
||||
|
||||
shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
|
||||
visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed)
|
||||
visual_out = self.cross_attention(visual_out, encoder_hidden_states=text_embed)
|
||||
visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
|
||||
|
||||
shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
|
||||
visual_out = (self.feed_forward_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed)
|
||||
visual_out = self.feed_forward(visual_out)
|
||||
visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
|
||||
|
||||
return visual_embed
|
||||
|
||||
|
||||
class Kandinsky5Transformer3DModel(
|
||||
ModelMixin,
|
||||
ConfigMixin,
|
||||
PeftAdapterMixin,
|
||||
FromOriginalModelMixin,
|
||||
CacheMixin,
|
||||
AttentionMixin,
|
||||
):
|
||||
"""
|
||||
A 3D Diffusion Transformer model for video-like data.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_visual_dim=4,
|
||||
in_text_dim=3584,
|
||||
in_text_dim2=768,
|
||||
time_dim=512,
|
||||
out_visual_dim=4,
|
||||
patch_size=(1, 2, 2),
|
||||
model_dim=2048,
|
||||
ff_dim=5120,
|
||||
num_text_blocks=2,
|
||||
num_visual_blocks=32,
|
||||
axes_dims=(16, 24, 24),
|
||||
visual_cond=False,
|
||||
attention_type: str = "regular",
|
||||
attention_causal: bool = None,
|
||||
attention_local: bool = None,
|
||||
attention_glob: bool = None,
|
||||
attention_window: int = None,
|
||||
attention_P: float = None,
|
||||
attention_wT: int = None,
|
||||
attention_wW: int = None,
|
||||
attention_wH: int = None,
|
||||
attention_add_sta: bool = None,
|
||||
attention_method: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
head_dim = sum(axes_dims)
|
||||
self.in_visual_dim = in_visual_dim
|
||||
self.model_dim = model_dim
|
||||
self.patch_size = patch_size
|
||||
self.visual_cond = visual_cond
|
||||
self.attention_type = attention_type
|
||||
|
||||
visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim
|
||||
|
||||
# Initialize embeddings
|
||||
self.time_embeddings = Kandinsky5TimeEmbeddings(model_dim, time_dim)
|
||||
self.text_embeddings = Kandinsky5TextEmbeddings(in_text_dim, model_dim)
|
||||
self.pooled_text_embeddings = Kandinsky5TextEmbeddings(in_text_dim2, time_dim)
|
||||
self.visual_embeddings = Kandinsky5VisualEmbeddings(
|
||||
visual_embed_dim, model_dim, patch_size
|
||||
)
|
||||
|
||||
# Initialize positional embeddings
|
||||
self.text_rope_embeddings = Kandinsky5RoPE1D(head_dim)
|
||||
self.visual_rope_embeddings = Kandinsky5RoPE3D(axes_dims)
|
||||
|
||||
# Initialize transformer blocks
|
||||
self.text_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
Kandinsky5TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim)
|
||||
for _ in range(num_text_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.visual_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
Kandinsky5TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim)
|
||||
for _ in range(num_visual_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
# Initialize output layer
|
||||
self.out_layer = Kandinsky5OutLayer(
|
||||
model_dim, time_dim, out_visual_dim, patch_size
|
||||
)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor, # x
|
||||
encoder_hidden_states: torch.FloatTensor, # text_embed
|
||||
timestep: Union[torch.Tensor, float, int], # time
|
||||
pooled_projections: torch.FloatTensor, # pooled_text_embed
|
||||
visual_rope_pos: Tuple[int, int, int],
|
||||
text_rope_pos: torch.LongTensor,
|
||||
scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0),
|
||||
sparse_params: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Transformer2DModelOutput, torch.FloatTensor]:
|
||||
"""
|
||||
Forward pass of the Kandinsky5 3D Transformer.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): Input visual states
|
||||
encoder_hidden_states (`torch.FloatTensor`): Text embeddings
|
||||
timestep (`torch.Tensor` or `float` or `int`): Current timestep
|
||||
pooled_projections (`torch.FloatTensor`): Pooled text embeddings
|
||||
visual_rope_pos (`Tuple[int, int, int]`): Position for visual RoPE
|
||||
text_rope_pos (`torch.LongTensor`): Position for text RoPE
|
||||
scale_factor (`Tuple[float, float, float]`, optional): Scale factor for RoPE
|
||||
sparse_params (`Dict[str, Any]`, optional): Parameters for sparse attention
|
||||
return_dict (`bool`, optional): Whether to return a dictionary
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_2d.Transformer2DModelOutput`] or `torch.FloatTensor`:
|
||||
The output of the transformer
|
||||
"""
|
||||
x = hidden_states
|
||||
text_embed = encoder_hidden_states
|
||||
time = timestep
|
||||
pooled_text_embed = pooled_projections
|
||||
|
||||
text_embed = self.text_embeddings(text_embed)
|
||||
time_embed = self.time_embeddings(time)
|
||||
time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed)
|
||||
visual_embed = self.visual_embeddings(x)
|
||||
text_rope = self.text_rope_embeddings(text_rope_pos)
|
||||
text_rope = text_rope.unsqueeze(dim=0)
|
||||
|
||||
for text_transformer_block in self.text_transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
text_embed = self._gradient_checkpointing_func(
|
||||
text_transformer_block, text_embed, time_embed, text_rope
|
||||
)
|
||||
else:
|
||||
text_embed = text_transformer_block(text_embed, time_embed, text_rope)
|
||||
|
||||
visual_shape = visual_embed.shape[:-1]
|
||||
visual_rope = self.visual_rope_embeddings(
|
||||
visual_shape, visual_rope_pos, scale_factor
|
||||
)
|
||||
to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False
|
||||
visual_embed, visual_rope = fractal_flatten(
|
||||
visual_embed, visual_rope, visual_shape, block_mask=to_fractal
|
||||
)
|
||||
|
||||
for visual_transformer_block in self.visual_transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
visual_embed = self._gradient_checkpointing_func(
|
||||
visual_transformer_block,
|
||||
visual_embed,
|
||||
text_embed,
|
||||
time_embed,
|
||||
visual_rope,
|
||||
sparse_params,
|
||||
)
|
||||
else:
|
||||
visual_embed = visual_transformer_block(
|
||||
visual_embed, text_embed, time_embed, visual_rope, sparse_params
|
||||
)
|
||||
|
||||
visual_embed = fractal_unflatten(
|
||||
visual_embed, visual_shape, block_mask=to_fractal
|
||||
)
|
||||
x = self.out_layer(visual_embed, text_embed, time_embed)
|
||||
|
||||
if not return_dict:
|
||||
return x
|
||||
|
||||
return Transformer2DModelOutput(sample=x)
|
||||
@@ -382,6 +382,7 @@ else:
"WuerstchenPriorPipeline",
]
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"]
_import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"]
_import_structure["skyreels_v2"] = [
"SkyReelsV2DiffusionForcingPipeline",
"SkyReelsV2DiffusionForcingImageToVideoPipeline",
@@ -787,6 +788,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline
from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline
from .kandinsky5 import Kandinsky5T2VPipeline
from .wuerstchen import (
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,

src/diffusers/pipelines/kandinsky5/__init__.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING

from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}


try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403

_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_kandinsky"] = ["Kandinsky5T2VPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()

except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_kandinsky import Kandinsky5T2VPipeline

else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)

for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
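With the lazy-module registration above (and the matching entries in `diffusers/__init__.py` and `diffusers/pipelines/__init__.py` earlier in this diff), the pipeline class resolves from the package root without importing the heavy modules eagerly. A quick smoke test, assuming this branch is installed:

```py
from diffusers import Kandinsky5T2VPipeline
from diffusers.pipelines.kandinsky5 import Kandinsky5T2VPipeline as Kandinsky5T2VPipelineDirect

# Both import paths resolve to the same class object through the lazy-module machinery.
assert Kandinsky5T2VPipeline is Kandinsky5T2VPipelineDirect
```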
src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py (new file, 838 lines)
@@ -0,0 +1,838 @@
|
||||
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from transformers import Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...loaders import KandinskyLoraLoaderMixin
|
||||
from ...models import AutoencoderKLHunyuanVideo
|
||||
from ...models.transformers import Kandinsky5Transformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import KandinskyPipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import Kandinsky5T2VPipeline
|
||||
>>> from diffusers.utils import export_to_video
|
||||
|
||||
>>> # Available models:
|
||||
>>> # ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers
|
||||
>>> # ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers
|
||||
>>> # ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers
|
||||
>>> # ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers
|
||||
|
||||
>>> model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
|
||||
>>> pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
|
||||
>>> pipe = pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A cat and a dog baking a cake together in a kitchen."
|
||||
>>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
|
||||
|
||||
>>> output = pipe(
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... height=512,
|
||||
... width=768,
|
||||
... num_frames=121,
|
||||
... num_inference_steps=50,
|
||||
... guidance_scale=5.0,
|
||||
... ).frames[0]
|
||||
|
||||
>>> export_to_video(output, "output.mp4", fps=24, quality=9)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
"""Clean text using ftfy if available and unescape HTML entities."""
|
||||
if is_ftfy_available():
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
"""Normalize whitespace in text by replacing multiple spaces with single space."""
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def prompt_clean(text):
|
||||
"""Apply both basic cleaning and whitespace normalization to prompts."""
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
return text
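A quick illustration of the prompt-cleaning helpers above (HTML entities unescaped, whitespace collapsed), assuming the module path added by this diff is importable:

```py
from diffusers.pipelines.kandinsky5.pipeline_kandinsky import prompt_clean

print(prompt_clean("A   cat &amp; a dog\n baking a cake"))
# -> "A cat & a dog baking a cake"
```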
|
||||
|
||||
|
||||
class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using Kandinsky 5.0.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
transformer ([`Kandinsky5Transformer3DModel`]):
|
||||
Conditional Transformer to denoise the encoded video latents.
|
||||
vae ([`AutoencoderKLHunyuanVideo`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
|
||||
Frozen text-encoder (Qwen2.5-VL).
|
||||
tokenizer ([`AutoProcessor`]):
|
||||
Tokenizer for Qwen2.5-VL.
|
||||
text_encoder_2 ([`CLIPTextModel`]):
|
||||
Frozen CLIP text encoder.
|
||||
tokenizer_2 ([`CLIPTokenizer`]):
|
||||
Tokenizer for CLIP.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: Kandinsky5Transformer3DModel,
|
||||
vae: AutoencoderKLHunyuanVideo,
|
||||
text_encoder: Qwen2_5_VLForConditionalGeneration,
|
||||
tokenizer: Qwen2VLProcessor,
|
||||
text_encoder_2: CLIPTextModel,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.prompt_template = "\n".join(["<|im_start|>system\nYou are a promt engineer. Describe the video in detail.",
|
||||
"Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.",
|
||||
"Describe the location of the video, main characters or objects and their action.",
|
||||
"Describe the dynamism of the video and presented actions.",
|
||||
"Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.",
|
||||
"Describe the visual effects, postprocessing and transitions if they are presented in the video.",
|
||||
"Pay attention to the order of key actions shown in the scene.<|im_end|>",
|
||||
"<|im_start|>user\n{}<|im_end|>"])
|
||||
self.prompt_template_encode_start_idx = 129
|
||||
|
||||
self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio
|
||||
self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
@staticmethod
|
||||
def fast_sta_nabla(
|
||||
T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda"
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Create a sparse temporal attention (STA) mask for efficient video generation.
|
||||
|
||||
This method generates a mask that limits attention to nearby frames and spatial positions,
|
||||
reducing computational complexity for video generation.
|
||||
|
||||
Args:
|
||||
T (int): Number of temporal frames
|
||||
H (int): Height in latent space
|
||||
W (int): Width in latent space
|
||||
wT (int): Temporal attention window size
|
||||
wH (int): Height attention window size
|
||||
wW (int): Width attention window size
|
||||
device (str): Device to create tensor on
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Sparse attention mask of shape (T*H*W, T*H*W)
|
||||
"""
|
||||
l = torch.Tensor([T, H, W]).amax()
|
||||
r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
|
||||
mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
|
||||
sta_t, sta_h, sta_w = (
|
||||
mat[:T, :T].flatten(),
|
||||
mat[:H, :H].flatten(),
|
||||
mat[:W, :W].flatten(),
|
||||
)
|
||||
sta_t = sta_t <= wT // 2
|
||||
sta_h = sta_h <= wH // 2
|
||||
sta_w = sta_w <= wW // 2
|
||||
sta_hw = (
|
||||
(sta_h.unsqueeze(1) * sta_w.unsqueeze(0))
|
||||
.reshape(H, H, W, W)
|
||||
.transpose(1, 2)
|
||||
.flatten()
|
||||
)
|
||||
sta = (
|
||||
(sta_t.unsqueeze(1) * sta_hw.unsqueeze(0))
|
||||
.reshape(T, T, H * W, H * W)
|
||||
.transpose(1, 2)
|
||||
)
|
||||
return sta.reshape(T * H * W, T * H * W)
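`fast_sta_nabla` is a `@staticmethod`, so the sparse attention mask can be inspected without instantiating the pipeline. A small sketch of the output shape, assuming this branch is installed; the tiny T/H/W values are chosen only for illustration:

```py
import torch
from diffusers import Kandinsky5T2VPipeline

mask = Kandinsky5T2VPipeline.fast_sta_nabla(T=4, H=3, W=3, wT=3, wH=3, wW=3, device="cpu")
print(mask.shape)           # torch.Size([36, 36]) -> (T*H*W, T*H*W)
print(mask.float().mean())  # fraction of block pairs allowed to attend to each other
```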
|
||||
|
||||
def get_sparse_params(self, sample, device):
|
||||
"""
|
||||
Generate sparse attention parameters for the transformer based on sample dimensions.
|
||||
|
||||
This method computes the sparse attention configuration needed for efficient
|
||||
video processing in the transformer model.
|
||||
|
||||
Args:
|
||||
sample (torch.Tensor): Input sample tensor
|
||||
device (torch.device): Device to place tensors on
|
||||
|
||||
Returns:
|
||||
Dict: Dictionary containing sparse attention parameters
|
||||
"""
|
||||
assert self.transformer.config.patch_size[0] == 1
|
||||
B, T, H, W, _ = sample.shape
|
||||
T, H, W = (
|
||||
T // self.transformer.config.patch_size[0],
|
||||
H // self.transformer.config.patch_size[1],
|
||||
W // self.transformer.config.patch_size[2],
|
||||
)
|
||||
if self.transformer.config.attention_type == "nabla":
|
||||
sta_mask = self.fast_sta_nabla(
|
||||
T, H // 8, W // 8,
|
||||
self.transformer.config.attention_wT,
|
||||
self.transformer.config.attention_wH,
|
||||
self.transformer.config.attention_wW,
|
||||
device=device
|
||||
)
|
||||
|
||||
sparse_params = {
|
||||
"sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0),
|
||||
"attention_type": self.transformer.config.attention_type,
|
||||
"to_fractal": True,
|
||||
"P": self.transformer.config.attention_P,
|
||||
"wT": self.transformer.config.attention_wT,
|
||||
"wW": self.transformer.config.attention_wW,
|
||||
"wH": self.transformer.config.attention_wH,
|
||||
"add_sta": self.transformer.config.attention_add_sta,
|
||||
"visual_shape": (T, H, W),
|
||||
"method": self.transformer.config.attention_method,
|
||||
}
|
||||
else:
|
||||
sparse_params = None
|
||||
|
||||
return sparse_params
|
||||
|
||||
    def _encode_prompt_qwen(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 256,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Encode the prompt using the Qwen2.5-VL text encoder.

        This method processes the input prompt through the Qwen2.5-VL model to generate text
        embeddings suitable for video generation.

        Args:
            prompt (Union[str, List[str]]): Input prompt or list of prompts.
            device (torch.device): Device to run encoding on.
            num_videos_per_prompt (int): Number of videos to generate per prompt.
            max_sequence_length (int): Maximum sequence length for tokenization.
            dtype (torch.dtype): Data type for embeddings.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        full_texts = [self.prompt_template.format(p) for p in prompt]

        inputs = self.tokenizer(
            text=full_texts,
            images=None,
            videos=None,
            max_length=max_sequence_length + self.prompt_template_encode_start_idx,
            truncation=True,
            return_tensors="pt",
            padding=True,
        ).to(device)

        embeds = self.text_encoder(
            input_ids=inputs["input_ids"],
            return_dict=True,
            output_hidden_states=True,
        )["hidden_states"][-1][:, self.prompt_template_encode_start_idx:]

        attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx:]
        cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32)
        embeds = embeds.repeat_interleave(num_videos_per_prompt, dim=0)

        return embeds.to(dtype), cu_seqlens

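    # A small illustration of the `cu_seqlens` bookkeeping above, with toy numbers rather than a
    # real tokenizer output: per-prompt token counts are accumulated and left-padded with 0, the
    # layout variable-length attention kernels typically expect.
    #
    #     >>> import torch
    #     >>> import torch.nn.functional as F
    #     >>> attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])  # two prompts, lengths 3 and 4
    #     >>> cu = torch.cumsum(attention_mask.sum(1), dim=0)
    #     >>> F.pad(cu, (1, 0), value=0).to(torch.int32)
    #     tensor([0, 3, 7], dtype=torch.int32)
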
    def _encode_prompt_clip(
        self,
        prompt: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_videos_per_prompt: int = 1,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Encode the prompt using the CLIP text encoder.

        This method processes the input prompt through the CLIP model to generate pooled
        embeddings that capture global semantic information.

        Args:
            prompt (Union[str, List[str]]): Input prompt or list of prompts.
            device (torch.device): Device to run encoding on.
            num_videos_per_prompt (int): Number of videos to generate per prompt.
            dtype (torch.dtype): Data type for embeddings.

        Returns:
            torch.Tensor: Pooled text embeddings from CLIP.
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder_2.dtype

        inputs = self.tokenizer_2(
            prompt,
            max_length=77,
            truncation=True,
            add_special_tokens=True,
            padding="max_length",
            return_tensors="pt",
        ).to(device)

        pooled_embed = self.text_encoder_2(**inputs)["pooler_output"]

        # Duplicate each pooled embedding for every requested video so that copies of the same
        # prompt stay adjacent, matching the `repeat_interleave` used for the Qwen embeddings
        batch_size = len(prompt)
        pooled_embed = pooled_embed.repeat(1, num_videos_per_prompt)
        pooled_embed = pooled_embed.view(batch_size * num_videos_per_prompt, -1)

        return pooled_embed.to(dtype)

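    # Shape check for the duplication above, with illustrative values only: a [batch, dim] pooled
    # embedding repeated for `num_videos_per_prompt` videos keeps all copies of a prompt adjacent.
    #
    #     >>> import torch
    #     >>> pooled = torch.tensor([[1.0, 1.0], [2.0, 2.0]])   # batch of 2 prompts, dim 2
    #     >>> pooled.repeat(1, 3).view(2 * 3, -1)[:, 0]
    #     tensor([1., 1., 1., 2., 2., 2.])
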
    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        r"""
        Encodes a single prompt (positive or negative) into text encoder hidden states.

        This method combines embeddings from both the Qwen2.5-VL and CLIP text encoders to create
        a comprehensive text representation for video generation.

        Args:
            prompt (`str` or `List[str]`):
                Prompt to be encoded.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos to generate per prompt.
            max_sequence_length (`int`, *optional*, defaults to 512):
                Maximum sequence length for text encoding.
            device (`torch.device`, *optional*):
                Torch device.
            dtype (`torch.dtype`, *optional*):
                Torch dtype.

        Returns:
            Tuple[Dict[str, torch.Tensor], torch.Tensor]:
                - A dict with keys `"text_embeds"` (from Qwen) and `"pooled_embed"` (from CLIP)
                - Cumulative sequence lengths (`cu_seqlens`) for the Qwen embeddings
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        # Accept a bare string as well as a list of strings
        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [prompt_clean(p) for p in prompt]

        # Encode with Qwen2.5-VL
        prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
            prompt=prompt,
            device=device,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            dtype=dtype,
        )

        # Encode with CLIP
        prompt_embeds_clip = self._encode_prompt_clip(
            prompt=prompt,
            device=device,
            num_videos_per_prompt=num_videos_per_prompt,
            dtype=dtype,
        )

        prompt_embeds_dict = {
            "text_embeds": prompt_embeds_qwen,
            "pooled_embed": prompt_embeds_clip,
        }

        return prompt_embeds_dict, prompt_cu_seqlens

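    # The dual-encoder output above is consumed by the transformer as a pair of conditioning
    # signals. A rough sketch of the expected structure (dimension names are illustrative, not
    # the actual config values):
    #
    #     prompt_embeds_dict = {
    #         "text_embeds": ...,   # [batch * num_videos, seq_len, qwen_hidden_dim], per-token features
    #         "pooled_embed": ...,  # [batch * num_videos, clip_hidden_dim], one global vector per prompt
    #     }
    #     prompt_cu_seqlens         # [batch + 1] int32 offsets of each prompt's tokens
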
    def check_inputs(
        self,
        prompt,
        negative_prompt,
        height,
        width,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        """
        Validate input parameters for the pipeline.

        Args:
            prompt: Input prompt.
            negative_prompt: Negative prompt for guidance.
            height: Video height.
            width: Video width.
            prompt_embeds: Pre-computed prompt embeddings.
            negative_prompt_embeds: Pre-computed negative prompt embeddings.
            callback_on_step_end_tensor_inputs: Callback tensor inputs.

        Raises:
            ValueError: If any of the inputs is invalid.
        """
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
        elif negative_prompt is not None and (
            not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
        ):
            raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 480,
        width: int = 832,
        num_frames: int = 81,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Prepare the initial latent variables for video generation.

        This method creates random noise latents, or uses the provided latents, as the starting
        point for the denoising process.

        Args:
            batch_size (int): Number of videos to generate.
            num_channels_latents (int): Number of channels in the latent space.
            height (int): Height of the generated video.
            width (int): Width of the generated video.
            num_frames (int): Number of frames in the video.
            dtype (torch.dtype): Data type for the latents.
            device (torch.device): Device to create the latents on.
            generator (torch.Generator): Random number generator.
            latents (torch.Tensor): Pre-existing latents to use.

        Returns:
            torch.Tensor: Prepared latent tensor.
        """
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        shape = (
            batch_size,
            num_latent_frames,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
            num_channels_latents,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        if self.transformer.visual_cond:
            # For visual conditioning, concatenate the noise with empty conditioning latents and a mask
            visual_cond = torch.zeros_like(latents)
            visual_cond_mask = torch.zeros(
                [
                    batch_size,
                    num_latent_frames,
                    int(height) // self.vae_scale_factor_spatial,
                    int(width) // self.vae_scale_factor_spatial,
                    1,
                ],
                dtype=latents.dtype,
                device=latents.device,
            )
            latents = torch.cat([latents, visual_cond, visual_cond_mask], dim=-1)

        return latents

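    # Worked example for the latent shape above, using the pipeline's 512x768, 121-frame defaults
    # and assuming the usual 8x spatial / 4x temporal VAE compression (the actual factors come
    # from the VAE config):
    #
    #     num_latent_frames = (121 - 1) // 4 + 1 = 31
    #     latent height     = 512 // 8           = 64
    #     latent width      = 768 // 8           = 96
    #     shape             = (batch_size, 31, 64, 96, num_channels_latents)
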
    @property
    def guidance_scale(self):
        """Get the current guidance scale value."""
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        """Check if classifier-free guidance is enabled."""
        return self._guidance_scale > 1.0

    @property
    def num_timesteps(self):
        """Get the number of denoising timesteps."""
        return self._num_timesteps

    @property
    def interrupt(self):
        """Check if generation has been interrupted."""
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 512,
        width: int = 768,
        num_frames: int = 121,
        num_inference_steps: int = 50,
        guidance_scale: float = 5.0,
        scheduler_scale: float = 5.0,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (`guidance_scale` < `1`).
            height (`int`, defaults to `512`):
                The height in pixels of the generated video.
            width (`int`, defaults to `768`):
                The width in pixels of the generated video.
            num_frames (`int`, defaults to `121`):
                The number of frames in the generated video.
            num_inference_steps (`int`, defaults to `50`):
                The number of denoising steps.
            guidance_scale (`float`, defaults to `5.0`):
                Guidance scale as defined in classifier-free guidance.
            scheduler_scale (`float`, defaults to `5.0`):
                Scale factor for the custom flow matching scheduler.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A torch generator to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated video.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`KandinskyPipelineOutput`].
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function that is called at the end of each denoising step.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function.
            max_sequence_length (`int`, defaults to `512`):
                The maximum sequence length for text encoding.

        Examples:

        Returns:
            [`~KandinskyPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned
                where the first element is a list with the generated videos.
        """
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            negative_prompt,
            height,
            width,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
        )

        if num_frames % self.vae_scale_factor_temporal != 1:
            logger.warning(
                f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
            )
            num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
        num_frames = max(num_frames, 1)

        self._guidance_scale = guidance_scale
        self._interrupt = False

        device = self._execution_device
        dtype = self.transformer.dtype

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
            prompt = [prompt]
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0]

        # 3. Encode input prompt
        prompt_embeds_dict, prompt_cu_seqlens = self.encode_prompt(
            prompt=prompt,
            num_videos_per_prompt=num_videos_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
            dtype=dtype,
        )

        negative_prompt_embeds_dict = None
        negative_cu_seqlens = None

        if self.do_classifier_free_guidance:
            if negative_prompt is None:
                negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"

            if isinstance(negative_prompt, str):
                negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt]
            elif len(negative_prompt) != len(prompt):
                raise ValueError(
                    f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}."
                )

            negative_prompt_embeds_dict, negative_cu_seqlens = self.encode_prompt(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_visual_dim
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare RoPE positions for positional encoding
        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        visual_rope_pos = [
            torch.arange(num_latent_frames, device=device),
            torch.arange(height // self.vae_scale_factor_spatial // 2, device=device),
            torch.arange(width // self.vae_scale_factor_spatial // 2, device=device),
        ]

        text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device)

        negative_text_rope_pos = (
            torch.arange(negative_cu_seqlens.diff().max().item(), device=device)
            if negative_cu_seqlens is not None
            else None
        )

        # 7. Sparse params for efficient attention
        sparse_params = self.get_sparse_params(latents, device)

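        # Note on the grid sizes above: the spatial RoPE positions use an extra factor of 2 on top
        # of the VAE scale factor, which appears to mirror the `scale_factor=(1, 2, 2)` passed to
        # the transformer below (no temporal downscaling, 2x2 spatial patching). With the 512x768
        # defaults and an assumed 8x spatial VAE compression, that gives position grids of
        # 64 // 2 = 32 and 96 // 2 = 48.
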
        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt)

                # Predict the velocity for the conditional branch
                pred_velocity = self.transformer(
                    hidden_states=latents.to(dtype),
                    encoder_hidden_states=prompt_embeds_dict["text_embeds"].to(dtype),
                    pooled_projections=prompt_embeds_dict["pooled_embed"].to(dtype),
                    timestep=timestep.to(dtype),
                    visual_rope_pos=visual_rope_pos,
                    text_rope_pos=text_rope_pos,
                    scale_factor=(1, 2, 2),
                    sparse_params=sparse_params,
                    return_dict=True,
                ).sample

                if self.do_classifier_free_guidance and negative_prompt_embeds_dict is not None:
                    uncond_pred_velocity = self.transformer(
                        hidden_states=latents.to(dtype),
                        encoder_hidden_states=negative_prompt_embeds_dict["text_embeds"].to(dtype),
                        pooled_projections=negative_prompt_embeds_dict["pooled_embed"].to(dtype),
                        timestep=timestep.to(dtype),
                        visual_rope_pos=visual_rope_pos,
                        text_rope_pos=negative_text_rope_pos,
                        scale_factor=(1, 2, 2),
                        sparse_params=sparse_params,
                        return_dict=True,
                    ).sample

                    # Classifier-free guidance: extrapolate from the unconditional toward the
                    # conditional prediction by `guidance_scale`
                    pred_velocity = uncond_pred_velocity + guidance_scale * (
                        pred_velocity - uncond_pred_velocity
                    )

                # Compute the previous sample with the scheduler. Only the noise channels are
                # updated; any visual-conditioning channels appended in `prepare_latents` are kept.
                latents[:, :, :, :, :num_channels_latents] = self.scheduler.step(
                    pred_velocity, t, latents[:, :, :, :, :num_channels_latents], return_dict=False
                )[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds_dict = callback_outputs.pop("prompt_embeds", prompt_embeds_dict)
                    negative_prompt_embeds_dict = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds_dict)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        # 9. Post-processing - keep only the denoised latent channels
        latents = latents[:, :, :, :, :num_channels_latents]

        # 10. Decode latents to video
        if output_type != "latent":
            latents = latents.to(self.vae.dtype)
            # Reshape latents to the [batch * num_videos, channels, frames, height, width] layout the VAE expects
            video = latents.reshape(
                batch_size,
                num_videos_per_prompt,
                (num_frames - 1) // self.vae_scale_factor_temporal + 1,
                height // self.vae_scale_factor_spatial,
                width // self.vae_scale_factor_spatial,
                num_channels_latents,
            )
            video = video.permute(0, 1, 5, 2, 3, 4)  # [batch, num_videos, channels, frames, height, width]
            video = video.reshape(
                batch_size * num_videos_per_prompt,
                num_channels_latents,
                (num_frames - 1) // self.vae_scale_factor_temporal + 1,
                height // self.vae_scale_factor_spatial,
                width // self.vae_scale_factor_spatial,
            )

            # Normalize and decode through the VAE
            video = video / self.vae.config.scaling_factor
            video = self.vae.decode(video).sample
            video = self.video_processor.postprocess_video(video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return KandinskyPipelineOutput(frames=video)

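# A rough end-to-end usage sketch for the pipeline above, in the spirit of EXAMPLE_DOC_STRING.
# The checkpoint id is a placeholder and the settings are illustrative, not prescriptive.
#
#     >>> import torch
#     >>> from diffusers import Kandinsky5T2VPipeline
#     >>> from diffusers.utils import export_to_video
#
#     >>> pipe = Kandinsky5T2VPipeline.from_pretrained(
#     ...     "<kandinsky-5-t2v-checkpoint>", torch_dtype=torch.bfloat16  # placeholder repo id
#     ... )
#     >>> pipe.to("cuda")
#     >>> output = pipe(
#     ...     prompt="A red panda rafting down a forest river at sunrise",
#     ...     height=512,
#     ...     width=768,
#     ...     num_frames=121,
#     ...     num_inference_steps=50,
#     ...     guidance_scale=5.0,
#     ... )
#     >>> export_to_video(output.frames[0], "kandinsky5.mp4", fps=24)
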
src/diffusers/pipelines/kandinsky5/pipeline_output.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from dataclasses import dataclass

import torch

from diffusers.utils import BaseOutput


@dataclass
class KandinskyPipelineOutput(BaseOutput):
    r"""
    Output class for Kandinsky video pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    frames: torch.Tensor
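
# Minimal illustration of how the output dataclass is consumed (values are placeholders):
#
#     >>> import torch
#     >>> out = KandinskyPipelineOutput(frames=torch.zeros(1, 121, 3, 512, 768))
#     >>> out.frames.shape   # dataclass attribute access; BaseOutput also allows out["frames"]
#     torch.Size([1, 121, 3, 512, 768])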