Compare commits


1 Commit

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| yiyixuxu | 4a56b5083b | up | 2026-02-10 04:57:17 +01:00 |
52 changed files with 796 additions and 1016 deletions


@@ -29,31 +29,8 @@ text_encoder = AutoModel.from_pretrained(
)
```
## Custom models
[`AutoModel`] also loads models from the [Hub](https://huggingface.co/models) that aren't included in Diffusers. Set `trust_remote_code=True` in [`AutoModel.from_pretrained`] to load custom models.
A custom model repository needs a Python module with the model class, and a `config.json` with an `auto_map` entry that maps `"AutoModel"` to `"module_file.ClassName"`.
```
custom/custom-transformer-model/
├── config.json
├── my_model.py
└── diffusion_pytorch_model.safetensors
```
The `config.json` includes the `auto_map` field pointing to the custom class.
```json
{
"auto_map": {
"AutoModel": "my_model.MyCustomModel"
}
}
```
Then load it with `trust_remote_code=True`.
```py
import torch
from diffusers import AutoModel
@@ -63,39 +40,7 @@ transformer = AutoModel.from_pretrained(
)
```
For a real-world example, [Overworld/Waypoint-1-Small](https://huggingface.co/Overworld/Waypoint-1-Small/tree/main/transformer) hosts a custom `WorldModel` class whose implementation spans several modules in its `transformer` subfolder.
```
transformer/
├── config.json # auto_map: "model.WorldModel"
├── model.py
├── attn.py
├── nn.py
├── cache.py
├── quantize.py
├── __init__.py
└── diffusion_pytorch_model.safetensors
```
```py
import torch
from diffusers import AutoModel
transformer = AutoModel.from_pretrained(
"Overworld/Waypoint-1-Small", subfolder="transformer", trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="cuda"
)
```
If the custom model inherits from the [`ModelMixin`] class, it gets access to the same features as Diffusers model classes, like [regional compilation](../optimization/fp16#regional-compilation) and [group offloading](../optimization/memory#group-offloading).
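For illustration, here is a minimal sketch of what the `my_model.py` module from the layout above could contain. `MyCustomModel` and its `hidden_dim` argument are hypothetical; inheriting from [`ModelMixin`] (plus [`ConfigMixin`] for `config.json` handling) is what opts the class into `from_pretrained` and the features above.
```py
# my_model.py - hypothetical custom model module
import torch
import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class MyCustomModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, hidden_dim: int = 64):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states)
```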
> [!WARNING]
> As a precaution with `trust_remote_code=True`, pass a commit hash to the `revision` argument in [`AutoModel.from_pretrained`] to pin the code you execute and ensure it hasn't been updated with malicious changes (unless you fully trust the model owners).
>
> ```py
> transformer = AutoModel.from_pretrained(
> "Overworld/Waypoint-1-Small", subfolder="transformer", trust_remote_code=True, revision="a3d8cb2"
> )
> ```
> [!NOTE]
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.


@@ -297,6 +297,8 @@ else:
"ComponentSpec",
"ModularPipeline",
"ModularPipelineBlocks",
"InputParam",
"OutputParam",
]
)
_import_structure["optimization"] = [
@@ -1060,7 +1062,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ZImageTransformer2DModel,
attention_backend,
)
from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks
from .modular_pipelines import ComponentsManager, ComponentSpec, InputParam, ModularPipeline, ModularPipelineBlocks, OutputParam
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
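With `InputParam` and `OutputParam` registered in both `_import_structure` and the `TYPE_CHECKING` branch, the names resolve from the package root through either the lazy or the eager import path; a quick sanity check, assuming a build that includes this change:
```py
# both import paths now expose the modular pipeline param classes
from diffusers import InputParam, OutputParam
```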


@@ -2321,14 +2321,6 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
prefix = "diffusion_model."
original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
has_lora_down_up = any("lora_down" in k or "lora_up" in k for k in original_state_dict.keys())
if has_lora_down_up:
temp_state_dict = {}
for k, v in original_state_dict.items():
new_key = k.replace("lora_down", "lora_A").replace("lora_up", "lora_B")
temp_state_dict[new_key] = v
original_state_dict = temp_state_dict
num_double_layers = 0
num_single_layers = 0
for key in original_state_dict.keys():
@@ -2345,15 +2337,13 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
attn_prefix = f"single_transformer_blocks.{sl}.attn"
for lora_key in lora_keys:
linear1_key = f"{single_block_prefix}.linear1.{lora_key}.weight"
if linear1_key in original_state_dict:
converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
linear1_key
)
converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
f"{single_block_prefix}.linear1.{lora_key}.weight"
)
linear2_key = f"{single_block_prefix}.linear2.{lora_key}.weight"
if linear2_key in original_state_dict:
converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(linear2_key)
converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(
f"{single_block_prefix}.linear2.{lora_key}.weight"
)
for dl in range(num_double_layers):
transformer_block_prefix = f"transformer_blocks.{dl}"
@@ -2362,10 +2352,6 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
for attn_type in attn_types:
attn_prefix = f"{transformer_block_prefix}.attn"
qkv_key = f"double_blocks.{dl}.{attn_type}.qkv.{lora_key}.weight"
if qkv_key not in original_state_dict:
continue
fused_qkv_weight = original_state_dict.pop(qkv_key)
if lora_key == "lora_A":
@@ -2397,9 +2383,8 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
for org_proj, diff_proj in proj_mappings:
for lora_key in lora_keys:
original_key = f"double_blocks.{dl}.{org_proj}.{lora_key}.weight"
if original_key in original_state_dict:
diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
mlp_mappings = [
("img_mlp.0", "ff.linear_in"),
@@ -2410,27 +2395,8 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
for org_mlp, diff_mlp in mlp_mappings:
for lora_key in lora_keys:
original_key = f"double_blocks.{dl}.{org_mlp}.{lora_key}.weight"
if original_key in original_state_dict:
diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
extra_mappings = {
"img_in": "x_embedder",
"txt_in": "context_embedder",
"time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
"time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
"final_layer.linear": "proj_out",
"final_layer.adaLN_modulation.1": "norm_out.linear",
"single_stream_modulation.lin": "single_stream_modulation.linear",
"double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
"double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
}
for org_key, diff_key in extra_mappings.items():
for lora_key in lora_keys:
original_key = f"{org_key}.{lora_key}.weight"
if original_key in original_state_dict:
converted_state_dict[f"{diff_key}.{lora_key}.weight"] = original_state_dict.pop(original_key)
diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
if len(original_state_dict) > 0:
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
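For orientation, the function is a pure key-renaming pass over the LoRA state dict; a toy sketch of the `("img_mlp.0", "ff.linear_in")` mapping above, with an arbitrary tensor shape:
```py
import torch

# toy non-diffusers Flux2 LoRA entry for double block 0
original_state_dict = {"double_blocks.0.img_mlp.0.lora_A.weight": torch.zeros(4, 8)}

# renamed into the diffusers layout used by `transformer_blocks`
converted_state_dict = {
    "transformer_blocks.0.ff.linear_in.lora_A.weight": original_state_dict.pop(
        "double_blocks.0.img_mlp.0.lora_A.weight"
    )
}
assert not original_state_dict  # every key must be consumed, as the final check enforces
```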


@@ -21,7 +21,7 @@ from torch.nn import functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import BaseOutput, apply_lora_scale, logging
from ...utils import BaseOutput, logging
from ..attention import AttentionMixin
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -598,7 +598,6 @@ class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModel
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
@apply_lora_scale("cross_attention_kwargs")
def forward(
self,
sample: torch.Tensor,


@@ -20,11 +20,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import (
BaseOutput,
apply_lora_scale,
logging,
)
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
@@ -154,7 +150,6 @@ class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
return controlnet
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -202,6 +197,20 @@ class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
if self.input_hint_block is not None:
@@ -314,6 +323,10 @@ class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (controlnet_block_samples, controlnet_single_block_samples)
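The same scale/unscale bracket recurs in each `forward` below; as a self-contained reference, here is the idiom in isolation. `forward_with_lora_scale` and `model` are illustrative stand-ins, while `USE_PEFT_BACKEND`, `scale_lora_layers`, and `unscale_lora_layers` are the Diffusers utilities imported above:
```py
import torch
import torch.nn as nn

from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers


def forward_with_lora_scale(model: nn.Module, x: torch.Tensor, attention_kwargs=None):
    # pop "scale" from a copy so downstream attention processors never see it
    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight every PEFT LoRA layer by `lora_scale` for this call only
        scale_lora_layers(model, lora_scale)

    output = model(x)

    if USE_PEFT_BACKEND:
        # restore the default LoRA weighting afterwards
        unscale_lora_layers(model, lora_scale)
    return output
```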


@@ -20,12 +20,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import (
BaseOutput,
apply_lora_scale,
deprecate,
logging,
)
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin
from ..cache_utils import CacheMixin
from ..controlnets.controlnet import zero_module
@@ -128,7 +123,6 @@ class QwenImageControlNetModel(
return controlnet
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -187,6 +181,20 @@ class QwenImageControlNetModel(
standard_warn=False,
)
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.img_in(hidden_states)
# add
@@ -248,6 +256,10 @@ class QwenImageControlNetModel(
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return controlnet_block_samples


@@ -20,7 +20,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import BaseOutput, apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
@@ -117,7 +117,6 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -130,6 +129,21 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -204,6 +218,10 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
if not return_dict:


@@ -21,7 +21,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin, JointTransformerBlock
from ..attention_processor import Attention, FusedJointAttnProcessor2_0
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
@@ -269,7 +269,6 @@ class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMix
return controlnet
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -309,6 +308,21 @@ class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMix
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
if self.pos_embed is not None and hidden_states.ndim != 4:
raise ValueError("hidden_states must be 4D when pos_embed is used")
@@ -368,6 +382,10 @@ class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMix
# 6. scaling
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (controlnet_block_res_samples,)


@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin
from ..attention_processor import (
@@ -397,7 +397,6 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.FloatTensor,
@@ -406,6 +405,21 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
height, width = hidden_states.shape[-2:]
# Apply patch embedding, timestep embedding, and project the caption embeddings.
@@ -472,6 +486,10 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)


@@ -20,7 +20,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, AttentionMixin, FeedForward
from ..attention_processor import CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
@@ -363,7 +363,6 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -375,6 +374,21 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_frames, channels, height, width = hidden_states.shape
# 1. Time embedding
@@ -440,6 +454,10 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)


@@ -20,7 +20,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, AttentionMixin, FeedForward
from ..attention_processor import CogVideoXAttnProcessor2_0
@@ -620,7 +620,6 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
]
)
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -633,6 +632,21 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
id_vit_hidden: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# fuse clip and insightface
valid_face_emb = None
if self.is_train_face:
@@ -706,6 +720,10 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)


@@ -20,7 +20,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin
from ..attention_processor import (
Attention,
@@ -414,7 +414,6 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -427,6 +426,21 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -513,6 +527,10 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)


@@ -8,7 +8,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -581,7 +581,6 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -622,6 +621,20 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype)
@@ -702,6 +715,10 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)


@@ -22,8 +22,10 @@ from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_bria import BriaAttnProcessor
from ...utils import (
apply_lora_scale,
USE_PEFT_BACKEND,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
@@ -508,7 +510,6 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
]
self.caption_projection = nn.ModuleList(caption_projection)
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -544,7 +545,20 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype)
@@ -631,6 +645,10 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)


@@ -21,7 +21,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, FeedForward
@@ -473,7 +473,6 @@ class ChromaTransformer2DModel(
self.gradient_checkpointing = False
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -512,6 +511,20 @@ class ChromaTransformer2DModel(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
@@ -618,6 +631,10 @@ class ChromaTransformer2DModel(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)


@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -43,7 +43,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
encoder_hidden_states = hidden_states
if attn.fused_projections:
if not attn.is_cross_attention:
if attn.cross_attention_dim_head is None:
# In self-attention layers, we can fuse the entire QKV projection into a single linear
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
else:
@@ -219,10 +219,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
if is_cross_attention is not None:
self.is_cross_attention = is_cross_attention
else:
self.is_cross_attention = cross_attention_dim_head is not None
self.is_cross_attention = cross_attention_dim_head is not None
self.set_processor(processor)
@@ -230,7 +227,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
if getattr(self, "fused_projections", False):
return
if not self.is_cross_attention:
if self.cross_attention_dim_head is None:
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape
@@ -641,7 +638,6 @@ class ChronoEditTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -651,6 +647,21 @@ class ChronoEditTransformer3DModel(
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -718,6 +729,10 @@ class ChronoEditTransformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)


@@ -20,7 +20,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
@@ -703,7 +703,6 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -719,6 +718,21 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, height, width = hidden_states.shape
# 1. RoPE
@@ -765,6 +779,10 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)


@@ -22,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -634,7 +634,6 @@ class FluxTransformer2DModel(
self.gradient_checkpointing = False
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -676,6 +675,20 @@ class FluxTransformer2DModel(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
@@ -772,6 +785,10 @@ class FluxTransformer2DModel(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)


@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
@@ -774,7 +774,6 @@ class Flux2Transformer2DModel(
self.gradient_checkpointing = False
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -811,6 +810,20 @@ class Flux2Transformer2DModel(
`tuple` where the first element is the sample tensor.
"""
# 0. Handle input arguments
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
num_txt_tokens = encoder_hidden_states.shape[1]
@@ -895,6 +908,10 @@ class Flux2Transformer2DModel(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)


@@ -8,7 +8,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention
from ..embeddings import TimestepEmbedding, Timesteps
@@ -773,7 +773,6 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
return hidden_states, hidden_states_masks, img_sizes, img_ids
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -809,6 +808,21 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
"if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
)
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# spatial forward
batch_size = hidden_states.shape[0]
hidden_states_type = hidden_states.dtype
@@ -919,6 +933,10 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
if hidden_states_masks is not None:
hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)


@@ -22,7 +22,7 @@ from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
@@ -989,7 +989,6 @@ class HunyuanVideoTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -1001,6 +1000,21 @@ class HunyuanVideoTransformer3DModel(
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p, p_t = self.config.patch_size, self.config.patch_size_t
post_patch_num_frames = num_frames // p_t
@@ -1090,6 +1104,10 @@ class HunyuanVideoTransformer3DModel(
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)


@@ -22,7 +22,7 @@ from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
@@ -620,7 +620,6 @@ class HunyuanVideo15Transformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -634,6 +633,21 @@ class HunyuanVideo15Transformer3DModel(
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size_t, self.config.patch_size, self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -769,6 +783,10 @@ class HunyuanVideo15Transformer3DModel(
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)


@@ -20,7 +20,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, get_logger
from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
from ..cache_utils import CacheMixin
from ..embeddings import get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
@@ -198,7 +198,6 @@ class HunyuanVideoFramepackTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -218,6 +217,21 @@ class HunyuanVideoFramepackTransformer3DModel(
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p, p_t = self.config.patch_size, self.config.patch_size_t
post_patch_num_frames = num_frames // p_t
@@ -323,6 +337,10 @@ class HunyuanVideoFramepackTransformer3DModel(
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)
return Transformer2DModelOutput(sample=hidden_states)


@@ -23,7 +23,7 @@ from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -742,7 +742,6 @@ class HunyuanImageTransformer2DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -756,6 +755,21 @@ class HunyuanImageTransformer2DModel(
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
if hidden_states.ndim == 4:
batch_size, channels, height, width = hidden_states.shape
sizes = (height, width)
@@ -886,6 +900,10 @@ class HunyuanImageTransformer2DModel(
]
hidden_states = hidden_states.reshape(*final_dims)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)


@@ -22,7 +22,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, is_torch_version, logging
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -491,7 +491,6 @@ class LTXVideoTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -506,6 +505,21 @@ class LTXVideoTransformer3DModel(
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> torch.Tensor:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
@@ -554,6 +568,10 @@ class LTXVideoTransformer3DModel(
hidden_states = hidden_states * (1 + scale) + shift
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)


@@ -22,7 +22,14 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import BaseOutput, apply_lora_scale, is_torch_version, logging
from ...utils import (
USE_PEFT_BACKEND,
BaseOutput,
is_torch_version,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -1094,7 +1101,6 @@ class LTX2VideoTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -1165,6 +1171,21 @@ class LTX2VideoTransformer3DModel(
`tuple` is returned where the first element is the denoised video latent patch sequence and the second
element is the denoised audio latent patch sequence.
"""
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# Determine timestep for audio.
audio_timestep = audio_timestep if audio_timestep is not None else timestep
@@ -1320,6 +1341,10 @@ class LTX2VideoTransformer3DModel(
audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
audio_output = self.audio_proj_out(audio_hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output, audio_output)
return AudioVisualModelOutput(sample=output, audio_sample=audio_output)


@@ -22,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import LuminaFeedForward
from ..attention_processor import Attention
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
@@ -455,7 +455,6 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -465,6 +464,21 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# 1. Condition, positional & patch embedding
batch_size, _, height, width = hidden_states.shape
@@ -525,6 +539,10 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
)
output = torch.stack(output, dim=0)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)


@@ -21,7 +21,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
@@ -404,7 +404,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -414,6 +413,21 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> torch.Tensor:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p = self.config.patch_size
@@ -465,6 +479,10 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)


@@ -24,7 +24,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
@@ -829,7 +829,6 @@ class QwenImageTransformer2DModel(
self.gradient_checkpointing = False
self.zero_cond_t = zero_cond_t
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -888,6 +887,20 @@ class QwenImageTransformer2DModel(
"The mask-based approach is more flexible and supports variable-length sequences.",
standard_warn=False,
)
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.img_in(hidden_states)
@@ -968,6 +981,10 @@ class QwenImageTransformer2DModel(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)


@@ -21,7 +21,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
@@ -570,7 +570,6 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -583,6 +582,21 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -681,6 +695,10 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -18,7 +18,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, FeedForward, JointTransformerBlock
from ..attention_processor import (
@@ -245,7 +245,6 @@ class SD3Transformer2DModel(
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -285,6 +284,20 @@ class SD3Transformer2DModel(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
height, width = hidden_states.shape[-2:]
@@ -339,6 +352,10 @@ class SD3Transformer2DModel(
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -630,7 +630,6 @@ class SkyReelsV2Transformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -642,6 +641,21 @@ class SkyReelsV2Transformer3DModel(
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -757,6 +771,10 @@ class SkyReelsV2Transformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -42,7 +42,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
encoder_hidden_states = hidden_states
if attn.fused_projections:
if not attn.is_cross_attention:
if attn.cross_attention_dim_head is None:
# In self-attention layers, we can fuse the entire QKV projection into a single linear
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
else:
@@ -214,10 +214,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
if is_cross_attention is not None:
self.is_cross_attention = is_cross_attention
else:
self.is_cross_attention = cross_attention_dim_head is not None
self.is_cross_attention = cross_attention_dim_head is not None
self.set_processor(processor)
@@ -225,7 +222,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
if getattr(self, "fused_projections", False):
return
if not self.is_cross_attention:
if self.cross_attention_dim_head is None:
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape
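The `fuse_projections` hunk above relies on a simple identity: stacking the q/k/v weight matrices row-wise yields one linear layer whose chunked output reproduces the three separate projections. A standalone sketch with illustrative shapes:

```py
import torch

dim, inner_dim = 64, 64
to_q = torch.nn.Linear(dim, inner_dim, bias=True)
to_k = torch.nn.Linear(dim, inner_dim, bias=True)
to_v = torch.nn.Linear(dim, inner_dim, bias=True)

# fuse: concatenate along the output dimension, as in fuse_projections above
to_qkv = torch.nn.Linear(dim, 3 * inner_dim, bias=True)
to_qkv.weight.data = torch.cat([to_q.weight.data, to_k.weight.data, to_v.weight.data])
to_qkv.bias.data = torch.cat([to_q.bias.data, to_k.bias.data, to_v.bias.data])

x = torch.randn(2, 16, dim)
query, key, value = to_qkv(x).chunk(3, dim=-1)
assert torch.allclose(query, to_q(x), atol=1e-6)
```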
@@ -625,7 +622,6 @@ class WanTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -635,6 +631,21 @@ class WanTransformer3DModel(
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -702,6 +713,10 @@ class WanTransformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
@@ -54,7 +54,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
encoder_hidden_states = hidden_states
if attn.fused_projections:
if not attn.is_cross_attention:
if attn.cross_attention_dim_head is None:
# In self-attention layers, we can fuse the entire QKV projection into a single linear
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
else:
@@ -502,16 +502,13 @@ class WanAnimateFaceBlockCrossAttention(nn.Module, AttentionModuleMixin):
dim_head: int = 64,
eps: float = 1e-6,
cross_attention_dim_head: Optional[int] = None,
bias: bool = True,
processor=None,
):
super().__init__()
self.inner_dim = dim_head * heads
self.heads = heads
self.cross_attention_dim_head = cross_attention_dim_head
self.cross_attention_head_dim = cross_attention_dim_head
self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
self.use_bias = bias
self.is_cross_attention = cross_attention_dim_head is not None
# 1. Pre-Attention Norms for the hidden_states (video latents) and encoder_hidden_states (motion vector).
# NOTE: this is not used in "vanilla" WanAttention
@@ -519,10 +516,10 @@ class WanAnimateFaceBlockCrossAttention(nn.Module, AttentionModuleMixin):
self.pre_norm_kv = nn.LayerNorm(dim, eps, elementwise_affine=False)
# 2. QKV and Output Projections
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=bias)
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=bias)
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=bias)
self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=bias)
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=True)
# 3. QK Norm
# NOTE: this is applied after the reshape, so only over dim_head rather than dim_head * heads
@@ -685,10 +682,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
if is_cross_attention is not None:
self.is_cross_attention = is_cross_attention
else:
self.is_cross_attention = cross_attention_dim_head is not None
self.is_cross_attention = cross_attention_dim_head is not None
self.set_processor(processor)
@@ -696,7 +690,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
if getattr(self, "fused_projections", False):
return
if not self.is_cross_attention:
if self.cross_attention_dim_head is None:
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape
@@ -1147,7 +1141,6 @@ class WanAnimateTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -1186,6 +1179,21 @@ class WanAnimateTransformer3DModel(
Whether to return the output as a dict or tuple.
"""
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# Check that shapes match up
if pose_hidden_states is not None and pose_hidden_states.shape[2] + 1 != hidden_states.shape[2]:
raise ValueError(
@@ -1286,6 +1294,10 @@ class WanAnimateTransformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -20,7 +20,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin, FeedForward
from ..cache_utils import CacheMixin
from ..modeling_outputs import Transformer2DModelOutput
@@ -76,7 +76,6 @@ class WanVACETransformerBlock(nn.Module):
eps=eps,
added_kv_proj_dim=added_kv_proj_dim,
processor=WanAttnProcessor(),
is_cross_attention=True,
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
@@ -179,7 +178,6 @@ class WanVACETransformer3DModel(
_no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["WanTransformerBlock", "WanVACETransformerBlock"]
@register_to_config
def __init__(
@@ -263,7 +261,6 @@ class WanVACETransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -275,6 +272,21 @@ class WanVACETransformer3DModel(
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -367,6 +379,10 @@ class WanVACETransformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -20,12 +20,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import (
BaseOutput,
apply_lora_scale,
deprecate,
logging,
)
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..activations import get_activation
from ..attention import AttentionMixin
from ..attention_processor import (
@@ -979,7 +974,6 @@ class UNet2DConditionModel(
encoder_hidden_states = (encoder_hidden_states, image_embeds)
return encoder_hidden_states
@apply_lora_scale("cross_attention_kwargs")
def forward(
self,
sample: torch.Tensor,
@@ -1118,6 +1112,18 @@ class UNet2DConditionModel(
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
# 3. down
# We pop `scale` instead of reading it because otherwise `scale` would be propagated
# to the internal blocks and raise deprecation warnings, which would be confusing for users.
if cross_attention_kwargs is not None:
cross_attention_kwargs = cross_attention_kwargs.copy()
lora_scale = cross_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
is_adapter = down_intrablock_additional_residuals is not None
@@ -1233,6 +1239,10 @@ class UNet2DConditionModel(
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (sample,)
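A minimal sketch of the copy-then-pop idiom the comment above describes, using a toy dict: only the private copy is mutated, so the caller's kwargs keep their `scale` entry and the internal blocks never see it.

```py
caller_kwargs = {"scale": 0.5}

kwargs = caller_kwargs.copy()
lora_scale = kwargs.pop("scale", 1.0)

assert lora_scale == 0.5
assert "scale" in caller_kwargs   # caller's dict is untouched
assert "scale" not in kwargs      # internal blocks never see it
```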

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput, apply_lora_scale, deprecate, logging
from ...utils import BaseOutput, deprecate, logging
from ...utils.torch_utils import apply_freeu
from ..attention import AttentionMixin, BasicTransformerBlock
from ..attention_processor import (
@@ -1875,7 +1875,6 @@ class UNetMotionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLo
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
@apply_lora_scale("cross_attention_kwargs")
def forward(
self,
sample: torch.Tensor,

View File

@@ -21,7 +21,6 @@ from torch.utils.checkpoint import checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import apply_lora_scale
from ..attention import AttentionMixin, BasicTransformerBlock, SkipFFTransformerBlock
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -147,7 +146,6 @@ class UVit2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin):
self.gradient_checkpointing = False
@apply_lora_scale("cross_attention_kwargs")
def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)

View File

@@ -31,9 +31,7 @@ else:
"FluxAutoBeforeDenoiseStep",
"FluxAutoBlocks",
"FluxAutoDecodeStep",
"FluxAutoDenoiseStep",
"FluxKontextAutoBlocks",
"FluxKontextAutoDenoiseStep",
"FluxKontextBeforeDenoiseStep",
]
_import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"]
@@ -55,9 +53,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxAutoBeforeDenoiseStep,
FluxAutoBlocks,
FluxAutoDecodeStep,
FluxAutoDenoiseStep,
FluxKontextAutoBlocks,
FluxKontextAutoDenoiseStep,
FluxKontextBeforeDenoiseStep,
)
from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline

View File

@@ -201,37 +201,6 @@ class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks):
)
# denoise: text2image
class FluxAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxDenoiseStep]
block_names = ["denoise"]
block_trigger_inputs = [None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2image and img2img tasks."
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
)
# denoise: Flux Kontext
class FluxKontextAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxKontextDenoiseStep]
block_names = ["denoise"]
block_trigger_inputs = [None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents for Flux Kontext. "
"This is a auto pipeline block that works for text2image and img2img tasks."
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
)
# decode: all task (text2img, img2img)
class FluxAutoDecodeStep(AutoPipelineBlocks):
@@ -322,7 +291,7 @@ class FluxKontextAutoInputStep(AutoPipelineBlocks):
class FluxCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux"
block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxDenoiseStep]
block_names = ["input", "before_denoise", "denoise"]
@property
@@ -331,7 +300,7 @@ class FluxCoreDenoiseStep(SequentialPipelineBlocks):
"Core step that performs the denoising process. \n"
+ " - `FluxAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ " - `FluxDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ "This step supports text-to-image and image-to-image tasks for Flux:\n"
+ " - for image-to-image generation, you need to provide `image_latents`\n"
+ " - for text-to-image generation, all you need to provide is prompt embeddings."
@@ -340,7 +309,7 @@ class FluxCoreDenoiseStep(SequentialPipelineBlocks):
class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux-kontext"
block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextAutoDenoiseStep]
block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextDenoiseStep]
block_names = ["input", "before_denoise", "denoise"]
@property
@@ -349,7 +318,7 @@ class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
"Core step that performs the denoising process. \n"
+ " - `FluxKontextAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ " - `FluxKontextAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ " - `FluxKontextAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ " - `FluxKontextDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ "This step supports text-to-image and image-to-image tasks for Flux:\n"
+ " - for image-to-image generation, you need to provide `image_latents`\n"
+ " - for text-to-image generation, all you need to provide is prompt embeddings."

View File

@@ -18,6 +18,7 @@ import re
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Union
import ftfy
import torch
from transformers import AutoTokenizer, UMT5EncoderModel

View File

@@ -18,6 +18,7 @@ import re
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import ftfy
import PIL
import torch
from transformers import AutoTokenizer, UMT5EncoderModel

View File

@@ -19,6 +19,7 @@ import re
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Union
import ftfy
import torch
from PIL import Image
from transformers import AutoTokenizer, UMT5EncoderModel

View File

@@ -41,7 +41,7 @@ class GGUFQuantizer(DiffusersQuantizer):
self.compute_dtype = quantization_config.compute_dtype
self.pre_quantized = quantization_config.pre_quantized
self.modules_to_not_convert = quantization_config.modules_to_not_convert or []
self.modules_to_not_convert = quantization_config.modules_to_not_convert
if not isinstance(self.modules_to_not_convert, list):
self.modules_to_not_convert = [self.modules_to_not_convert]

View File

@@ -131,7 +131,6 @@ from .loading_utils import get_module_from_name, get_submodule_by_name, load_ima
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
apply_lora_scale,
check_peft_version,
delete_adapter_layers,
get_adapter_name,

View File

@@ -16,7 +16,6 @@ PEFT utilities: Utilities related to peft library
"""
import collections
import functools
import importlib
from typing import Optional
@@ -276,59 +275,6 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
module.set_scale(adapter_name, get_module_weight(weight, module_name))
def apply_lora_scale(kwargs_name: str = "joint_attention_kwargs"):
"""
Decorator to automatically handle LoRA layer scaling/unscaling in forward methods.
This decorator extracts the `lora_scale` from the specified kwargs parameter, applies scaling before the forward
pass, and ensures unscaling happens after, even if an exception occurs.
Args:
kwargs_name (`str`, defaults to `"joint_attention_kwargs"`):
The name of the keyword argument that contains the LoRA scale. Common values include
"joint_attention_kwargs", "attention_kwargs", "cross_attention_kwargs", etc.
"""
def decorator(forward_fn):
@functools.wraps(forward_fn)
def wrapper(self, *args, **kwargs):
from . import USE_PEFT_BACKEND
lora_scale = 1.0
attention_kwargs = kwargs.get(kwargs_name)
if attention_kwargs is not None:
    attention_kwargs = attention_kwargs.copy()
    kwargs[kwargs_name] = attention_kwargs
    if not USE_PEFT_BACKEND and attention_kwargs.get("scale", None) is not None:
        logger.warning(
            f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
        )
    lora_scale = attention_kwargs.pop("scale", 1.0)
# Apply LoRA scaling if using PEFT backend
if USE_PEFT_BACKEND:
scale_lora_layers(self, lora_scale)
try:
# Execute the forward pass
result = forward_fn(self, *args, **kwargs)
return result
finally:
# Always unscale, even if forward pass raises an exception
if USE_PEFT_BACKEND:
unscale_lora_layers(self, lora_scale)
return wrapper
return decorator
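For reference, a toy sketch of the contract this removed decorator provided, with self-contained stand-ins rather than real diffusers models (it runs only against a diffusers version that still ships `apply_lora_scale`): `scale` is popped from a copy of the kwargs dict before the wrapped forward executes.

```py
import torch
from diffusers.utils import apply_lora_scale  # removed by this commit

class TinyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    @apply_lora_scale("attention_kwargs")
    def forward(self, hidden_states, attention_kwargs=None):
        # the decorator already popped "scale" from the (copied) dict
        assert attention_kwargs is None or "scale" not in attention_kwargs
        return self.proj(hidden_states)

model = TinyTransformer()
out = model(torch.randn(1, 4), attention_kwargs={"scale": 0.5})
```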
def check_peft_version(min_version: str) -> None:
r"""
Checks if the version of PEFT is compatible.

View File

@@ -446,17 +446,16 @@ class ModelTesterMixin:
torch_device not in ["cuda", "xpu"],
reason="float16 and bfloat16 can only be used with an accelerator",
)
def test_keep_in_fp32_modules(self, tmp_path):
def test_keep_in_fp32_modules(self):
model = self.model_class(**self.get_init_dict())
fp32_modules = model._keep_in_fp32_modules
if fp32_modules is None or len(fp32_modules) == 0:
pytest.skip("Model does not have _keep_in_fp32_modules defined.")
# Save the model and reload with float16 dtype
# _keep_in_fp32_modules is only enforced during from_pretrained loading
model.save_pretrained(tmp_path)
model = self.model_class.from_pretrained(tmp_path, torch_dtype=torch.float16).to(torch_device)
# Test with float16
model.to(torch_device)
model.to(torch.float16)
for name, param in model.named_parameters():
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
@@ -471,7 +470,7 @@ class ModelTesterMixin:
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
@torch.no_grad()
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4, rtol=0):
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
fp32_modules = model._keep_in_fp32_modules or []
@@ -491,6 +490,10 @@ class ModelTesterMixin:
output = model(**inputs, return_dict=False)[0]
output_loaded = model_loaded(**inputs, return_dict=False)[0]
self._check_dtype_inference_output(output, output_loaded, dtype)
def _check_dtype_inference_output(self, output, output_loaded, dtype, atol=1e-4, rtol=0):
"""Check dtype inference output with configurable tolerance."""
assert_tensors_close(
output, output_loaded, atol=atol, rtol=rtol, msg=f"Loaded model output differs for {dtype}"
)

View File

@@ -176,7 +176,15 @@ class QuantizationTesterMixin:
model_quantized = self._create_quantized_model(config_kwargs)
model_quantized.to(torch_device)
# Get model dtype from first parameter
model_dtype = next(model_quantized.parameters()).dtype
inputs = self.get_dummy_inputs()
# Cast inputs to model dtype
inputs = {
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs.items()
}
output = model_quantized(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None"
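The casting comprehension above only touches floating-point tensors, so integer inputs such as timestep indices keep their dtype. A standalone sketch:

```py
import torch

model_dtype = torch.bfloat16
inputs = {"hidden_states": torch.randn(1, 4), "timestep": torch.tensor([1])}
inputs = {
    k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
    for k, v in inputs.items()
}
assert inputs["hidden_states"].dtype == torch.bfloat16
assert inputs["timestep"].dtype == torch.int64  # untouched
```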
@@ -221,8 +229,6 @@ class QuantizationTesterMixin:
init_lora_weights=False,
)
model.add_adapter(lora_config)
# Move LoRA adapter weights to device (they default to CPU)
model.to(torch_device)
inputs = self.get_dummy_inputs()
output = model(**inputs, return_dict=False)[0]
@@ -1015,6 +1021,9 @@ class GGUFTesterMixin(GGUFConfigMixin, QuantizationTesterMixin):
"""Test that dequantize() works correctly."""
self._test_dequantize({"compute_dtype": torch.bfloat16})
def test_gguf_quantized_layers(self):
self._test_quantized_layers({"compute_dtype": torch.bfloat16})
@is_quantization
@is_modelopt

View File

@@ -12,57 +12,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import unittest
import torch
from diffusers import WanTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class WanTransformer3DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return WanTransformer3DModel
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-wan22-transformer"
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
@property
def output_shape(self) -> tuple[int, ...]:
return (4, 2, 16, 16)
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
@property
def input_shape(self) -> tuple[int, ...]:
return (4, 2, 16, 16)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (4, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": (1, 2, 2),
"num_attention_heads": 2,
"attention_head_dim": 12,
@@ -76,160 +76,16 @@ class WanTransformer3DTesterConfig(BaseModelTesterConfig):
"qk_norm": "rms_norm_across_heads",
"rope_max_seq_len": 32,
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 4
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
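A hedged smoke-test sketch of the config above (the middle of `get_init_dict` is elided by the hunk, so this assumes it fully specifies the tiny model):

```py
import torch

config = WanTransformer3DTesterConfig()
model = config.model_class(**config.get_init_dict()).to(torch_device)
with torch.no_grad():
    output = model(**config.get_dummy_inputs(), return_dict=False)[0]
assert tuple(output.shape[1:]) == config.output_shape  # (4, 2, 16, 16)
```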
class TestWanTransformer3D(WanTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan Transformer 3D."""
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")
class TestWanTransformer3DMemory(WanTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan Transformer 3D."""
class TestWanTransformer3DTraining(WanTransformer3DTesterConfig, TrainingTesterMixin):
"""Training tests for Wan Transformer 3D."""
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestWanTransformer3DAttention(WanTransformer3DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Wan Transformer 3D."""
class WanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel
class TestWanTransformer3DCompile(WanTransformer3DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Wan Transformer 3D."""
class TestWanTransformer3DBitsAndBytes(WanTransformer3DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Wan Transformer 3D."""
@property
def torch_dtype(self):
return torch.float16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanTransformer3DTorchAo(WanTransformer3DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Wan Transformer 3D."""
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanTransformer3DGGUF(WanTransformer3DTesterConfig, GGUFTesterMixin):
"""GGUF quantization tests for Wan Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/blob/main/LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
return super()._create_quantized_model(
config_kwargs, config="Wan-AI/Wan2.2-I2V-A14B-Diffusers", subfolder="transformer", **extra_kwargs
)
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan I2V model dimensions.
Wan 2.2 I2V: in_channels=36, text_dim=4096
"""
return {
"hidden_states": randn_tensor(
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanTransformer3DGGUFCompile(WanTransformer3DTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Wan Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/blob/main/LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
return super()._create_quantized_model(
config_kwargs, config="Wan-AI/Wan2.2-I2V-A14B-Diffusers", subfolder="transformer", **extra_kwargs
)
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan I2V model dimensions.
Wan 2.2 I2V: in_channels=36, text_dim=4096
"""
return {
"hidden_states": randn_tensor(
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
def prepare_init_args_and_inputs_for_common(self):
return WanTransformer3DTests().prepare_init_args_and_inputs_for_common()

View File

@@ -12,62 +12,76 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import unittest
import torch
from diffusers import WanAnimateTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class WanAnimateTransformer3DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return WanAnimateTransformer3DModel
class WanAnimateTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = WanAnimateTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-wan-animate-transformer"
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
clip_seq_len = 12
clip_dim = 16
inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
face_height = 16 # Should be square and match `motion_encoder_size` below
face_width = 16
hidden_states = torch.randn((batch_size, 2 * num_channels + 4, num_frames + 1, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
clip_ref_features = torch.randn((batch_size, clip_seq_len, clip_dim)).to(torch_device)
pose_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
face_pixel_values = torch.randn((batch_size, 3, inference_segment_length, face_height, face_width)).to(
torch_device
)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"encoder_hidden_states_image": clip_ref_features,
"pose_hidden_states": pose_latents,
"face_pixel_values": face_pixel_values,
}
@property
def output_shape(self) -> tuple[int, ...]:
# Output has fewer channels than input (4 vs 12)
return (4, 21, 16, 16)
def input_shape(self):
return (12, 1, 16, 16)
@property
def input_shape(self) -> tuple[int, ...]:
return (12, 21, 16, 16)
def output_shape(self):
return (4, 1, 16, 16)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool | float | dict]:
def prepare_init_args_and_inputs_for_common(self):
# Use custom channel sizes since the default Wan Animate channel sizes will cause the motion encoder to
# contain the vast majority of the parameters in the test model
channel_sizes = {"4": 16, "8": 16, "16": 16}
return {
init_dict = {
"patch_size": (1, 2, 2),
"num_attention_heads": 2,
"attention_head_dim": 12,
@@ -91,219 +105,22 @@ class WanAnimateTransformer3DTesterConfig(BaseModelTesterConfig):
"face_encoder_num_heads": 2,
"inject_face_latents_blocks": 2,
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 4
num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
clip_seq_len = 12
clip_dim = 16
inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
face_height = 16 # Should be square and match `motion_encoder_size`
face_width = 16
return {
"hidden_states": randn_tensor(
(batch_size, 2 * num_channels + 4, num_frames + 1, height, width),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states_image": randn_tensor(
(batch_size, clip_seq_len, clip_dim),
generator=self.generator,
device=torch_device,
),
"pose_hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"face_pixel_values": randn_tensor(
(batch_size, 3, inference_segment_length, face_height, face_width),
generator=self.generator,
device=torch_device,
),
}
class TestWanAnimateTransformer3D(WanAnimateTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan Animate Transformer 3D."""
def test_output(self):
# Override test_output because the transformer output is expected to have fewer channels
# than the main transformer input.
expected_output_shape = (1, 4, 21, 16, 16)
super().test_output(expected_output_shape=expected_output_shape)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol (~1e-2) to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")
class TestWanAnimateTransformer3DMemory(WanAnimateTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan Animate Transformer 3D."""
class TestWanAnimateTransformer3DTraining(WanAnimateTransformer3DTesterConfig, TrainingTesterMixin):
"""Training tests for Wan Animate Transformer 3D."""
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanAnimateTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestWanAnimateTransformer3DAttention(WanAnimateTransformer3DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Wan Animate Transformer 3D."""
# Override test_output because the transformer output is expected to have fewer channels than the main
# transformer input.
def test_output(self):
expected_output_shape = (1, 4, 21, 16, 16)
super().test_output(expected_output_shape=expected_output_shape)
class TestWanAnimateTransformer3DCompile(WanAnimateTransformer3DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Wan Animate Transformer 3D."""
class WanAnimateTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = WanAnimateTransformer3DModel
def test_torch_compile_recompilation_and_graph_break(self):
# Skip: F.pad with mode="replicate" in WanAnimateFaceEncoder triggers importlib.import_module
# internally, which dynamo doesn't support tracing through.
pytest.skip("F.pad with replicate mode triggers unsupported import in torch.compile")
class TestWanAnimateTransformer3DBitsAndBytes(WanAnimateTransformer3DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Wan Animate Transformer 3D."""
@property
def torch_dtype(self):
return torch.float16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan Animate model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pose_hidden_states": randn_tensor(
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"face_pixel_values": randn_tensor(
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanAnimateTransformer3DTorchAo(WanAnimateTransformer3DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Wan Animate Transformer 3D."""
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan Animate model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pose_hidden_states": randn_tensor(
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"face_pixel_values": randn_tensor(
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanAnimateTransformer3DGGUF(WanAnimateTransformer3DTesterConfig, GGUFTesterMixin):
"""GGUF quantization tests for Wan Animate Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan Animate model dimensions.
Wan 2.2 Animate: in_channels=36 (2*16+4), text_dim=4096, image_dim=1280
"""
return {
"hidden_states": randn_tensor(
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pose_hidden_states": randn_tensor(
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"face_pixel_values": randn_tensor(
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanAnimateTransformer3DGGUFCompile(WanAnimateTransformer3DTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Wan Animate Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan Animate model dimensions.
Wan 2.2 Animate: in_channels=36 (2*16+4), text_dim=4096, image_dim=1280
"""
return {
"hidden_states": randn_tensor(
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"pose_hidden_states": randn_tensor(
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"face_pixel_values": randn_tensor(
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
def prepare_init_args_and_inputs_for_common(self):
return WanAnimateTransformer3DTests().prepare_init_args_and_inputs_for_common()

View File

@@ -1,271 +0,0 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from diffusers import WanVACETransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class WanVACETransformer3DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return WanVACETransformer3DModel
@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-wan-vace-transformer"
@property
def output_shape(self) -> tuple[int, ...]:
return (16, 2, 16, 16)
@property
def input_shape(self) -> tuple[int, ...]:
return (16, 2, 16, 16)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool | None]:
return {
"patch_size": (1, 2, 2),
"num_attention_heads": 2,
"attention_head_dim": 12,
"in_channels": 16,
"out_channels": 16,
"text_dim": 32,
"freq_dim": 256,
"ffn_dim": 32,
"num_layers": 4,
"cross_attn_norm": True,
"qk_norm": "rms_norm_across_heads",
"rope_max_seq_len": 32,
"vace_layers": [0, 2],
"vace_in_channels": 48, # 3 * in_channels = 3 * 16 = 48
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 16
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 32
sequence_length = 12
# VACE requires control_hidden_states with vace_in_channels (3 * in_channels)
vace_in_channels = 48
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"control_hidden_states": randn_tensor(
(batch_size, vace_in_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestWanVACETransformer3D(WanVACETransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Wan VACE Transformer 3D."""
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
pytest.skip("Tolerance requirements too high for meaningful test")
def test_model_parallelism(self, tmp_path):
# Skip: Device mismatch between cuda:0 and cuda:1 in VACE control flow
pytest.skip("Model parallelism not yet supported for WanVACE")
class TestWanVACETransformer3DMemory(WanVACETransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Wan VACE Transformer 3D."""
class TestWanVACETransformer3DTraining(WanVACETransformer3DTesterConfig, TrainingTesterMixin):
"""Training tests for Wan VACE Transformer 3D."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"WanVACETransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestWanVACETransformer3DAttention(WanVACETransformer3DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Wan VACE Transformer 3D."""
class TestWanVACETransformer3DCompile(WanVACETransformer3DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Wan VACE Transformer 3D."""
def test_torch_compile_repeated_blocks(self):
# WanVACE has two block types (WanTransformerBlock and WanVACETransformerBlock),
# so we need recompile_limit=2 instead of the default 1.
import torch._dynamo
import torch._inductor.utils
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
model.compile_repeated_blocks(fullgraph=True)
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=2),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
class TestWanVACETransformer3DBitsAndBytes(WanVACETransformer3DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Wan VACE Transformer 3D."""
@property
def torch_dtype(self):
return torch.float16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan VACE model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"control_hidden_states": randn_tensor(
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanVACETransformer3DTorchAo(WanVACETransformer3DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Wan VACE Transformer 3D."""
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the tiny Wan VACE model dimensions."""
return {
"hidden_states": randn_tensor(
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"control_hidden_states": randn_tensor(
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanVACETransformer3DGGUF(WanVACETransformer3DTesterConfig, GGUFTesterMixin):
"""GGUF quantization tests for Wan VACE Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan VACE model dimensions.
Wan 2.1 VACE: in_channels=16, text_dim=4096, vace_in_channels=96
"""
return {
"hidden_states": randn_tensor(
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"control_hidden_states": randn_tensor(
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
class TestWanVACETransformer3DGGUFCompile(WanVACETransformer3DTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Wan VACE Transformer 3D."""
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan VACE model dimensions.
Wan 2.1 VACE: in_channels=16, text_dim=4096, vace_in_channels=96
"""
return {
"hidden_states": randn_tensor(
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"control_hidden_states": randn_tensor(
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}

View File

@@ -168,7 +168,7 @@ def assert_tensors_close(
max_diff = abs_diff.max().item()
flat_idx = abs_diff.argmax().item()
max_idx = tuple(idx.item() for idx in torch.unravel_index(torch.tensor(flat_idx), actual.shape))
max_idx = tuple(torch.stack(torch.unravel_index(torch.tensor(flat_idx), actual.shape)).tolist())
threshold = atol + rtol * expected.abs()
mismatched = (abs_diff > threshold).sum().item()
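Standalone sketch of the index bookkeeping above: `torch.unravel_index` turns the flat argmax position back into a multi-dimensional coordinate.

```py
import torch

actual = torch.zeros(2, 3)
expected = torch.zeros(2, 3)
actual[1, 2] = 5.0

abs_diff = (actual - expected).abs()
flat_idx = abs_diff.argmax().item()
max_idx = tuple(i.item() for i in torch.unravel_index(torch.tensor(flat_idx), actual.shape))
assert max_idx == (1, 2)
```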