mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-04 01:45:15 +08:00
Compare commits
1 Commits
apply-lora
...
flux-test-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
700f882bc5 |
@@ -21,7 +21,7 @@ from torch.nn import functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import BaseOutput, apply_lora_scale, logging
|
||||
from ...utils import BaseOutput, logging
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -598,7 +598,6 @@ class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModel
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
@apply_lora_scale("cross_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
|
||||
@@ -20,11 +20,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
apply_lora_scale,
|
||||
logging,
|
||||
)
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin
|
||||
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
|
||||
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
|
||||
@@ -154,7 +150,6 @@ class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
|
||||
|
||||
return controlnet
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -202,6 +197,20 @@ class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
if self.input_hint_block is not None:
|
||||
@@ -314,6 +323,10 @@ class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
|
||||
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
|
||||
)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (controlnet_block_samples, controlnet_single_block_samples)
|
||||
|
||||
|
||||
@@ -20,12 +20,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
apply_lora_scale,
|
||||
deprecate,
|
||||
logging,
|
||||
)
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..controlnets.controlnet import zero_module
|
||||
@@ -128,7 +123,6 @@ class QwenImageControlNetModel(
|
||||
|
||||
return controlnet
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -187,6 +181,20 @@ class QwenImageControlNetModel(
|
||||
standard_warn=False,
|
||||
)
|
||||
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
|
||||
# add
|
||||
@@ -248,6 +256,10 @@ class QwenImageControlNetModel(
|
||||
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
|
||||
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return controlnet_block_samples
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import BaseOutput, apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
@@ -117,7 +117,6 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -130,6 +129,21 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
@@ -204,6 +218,10 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
|
||||
block_res_sample = controlnet_block(block_res_sample)
|
||||
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin, JointTransformerBlock
|
||||
from ..attention_processor import Attention, FusedJointAttnProcessor2_0
|
||||
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
|
||||
@@ -269,7 +269,6 @@ class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMix
|
||||
|
||||
return controlnet
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -309,6 +308,21 @@ class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMix
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
if self.pos_embed is not None and hidden_states.ndim != 4:
|
||||
raise ValueError("hidden_states must be 4D when pos_embed is used")
|
||||
|
||||
@@ -368,6 +382,10 @@ class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMix
|
||||
# 6. scaling
|
||||
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (controlnet_block_res_samples,)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_processor import (
|
||||
@@ -397,7 +397,6 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
@@ -406,6 +405,21 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
|
||||
# Apply patch embedding, timestep embedding, and project the caption embeddings.
|
||||
@@ -472,6 +486,10 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
|
||||
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
|
||||
)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, AttentionMixin, FeedForward
|
||||
from ..attention_processor import CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
|
||||
@@ -363,7 +363,6 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -375,6 +374,21 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_frames, channels, height, width = hidden_states.shape
|
||||
|
||||
# 1. Time embedding
|
||||
@@ -440,6 +454,10 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
|
||||
)
|
||||
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -20,7 +20,7 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, AttentionMixin, FeedForward
|
||||
from ..attention_processor import CogVideoXAttnProcessor2_0
|
||||
@@ -620,7 +620,6 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
|
||||
]
|
||||
)
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -633,6 +632,21 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
|
||||
id_vit_hidden: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# fuse clip and insightface
|
||||
valid_face_emb = None
|
||||
if self.is_train_face:
|
||||
@@ -706,6 +720,10 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
|
||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
|
||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -20,7 +20,7 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
@@ -414,7 +414,6 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -427,6 +426,21 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
|
||||
controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
@@ -513,6 +527,10 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
|
||||
hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
|
||||
output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -581,7 +581,6 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -622,6 +621,20 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
@@ -702,6 +715,10 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -22,8 +22,10 @@ from ...models.modeling_outputs import Transformer2DModelOutput
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.transformers.transformer_bria import BriaAttnProcessor
|
||||
from ...utils import (
|
||||
apply_lora_scale,
|
||||
USE_PEFT_BACKEND,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin, FeedForward
|
||||
@@ -508,7 +510,6 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
]
|
||||
self.caption_projection = nn.ModuleList(caption_projection)
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -544,7 +545,20 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
@@ -631,6 +645,10 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, deprecate, logging
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.import_utils import is_torch_npu_available
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
@@ -473,7 +473,6 @@ class ChromaTransformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -512,6 +511,20 @@ class ChromaTransformer2DModel(
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
@@ -618,6 +631,10 @@ class ChromaTransformer2DModel(
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, deprecate, logging
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
@@ -638,7 +638,6 @@ class ChronoEditTransformer3DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -648,6 +647,21 @@ class ChronoEditTransformer3DModel(
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
@@ -715,6 +729,10 @@ class ChronoEditTransformer3DModel(
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
@@ -703,7 +703,6 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -719,6 +718,21 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
|
||||
] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, height, width = hidden_states.shape
|
||||
|
||||
# 1. RoPE
|
||||
@@ -765,6 +779,10 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
|
||||
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
@@ -634,7 +634,6 @@ class FluxTransformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -676,6 +675,20 @@ class FluxTransformer2DModel(
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
@@ -772,6 +785,10 @@ class FluxTransformer2DModel(
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -774,7 +774,6 @@ class Flux2Transformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -811,6 +810,20 @@ class Flux2Transformer2DModel(
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 0. Handle input arguments
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
num_txt_tokens = encoder_hidden_states.shape[1]
|
||||
|
||||
@@ -895,6 +908,10 @@ class Flux2Transformer2DModel(
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.modeling_outputs import Transformer2DModelOutput
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import apply_lora_scale, deprecate, logging
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
@@ -773,7 +773,6 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
|
||||
return hidden_states, hidden_states_masks, img_sizes, img_ids
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -809,6 +808,21 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
"if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
|
||||
)
|
||||
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# spatial forward
|
||||
batch_size = hidden_states.shape[0]
|
||||
hidden_states_type = hidden_states.dtype
|
||||
@@ -919,6 +933,10 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
if hidden_states_masks is not None:
|
||||
hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -22,7 +22,7 @@ from diffusers.loaders import FromOriginalModelMixin
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
@@ -989,7 +989,6 @@ class HunyuanVideoTransformer3DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -1001,6 +1000,21 @@ class HunyuanVideoTransformer3DModel(
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p, p_t = self.config.patch_size, self.config.patch_size_t
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
@@ -1090,6 +1104,10 @@ class HunyuanVideoTransformer3DModel(
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from diffusers.loaders import FromOriginalModelMixin
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
@@ -620,7 +620,6 @@ class HunyuanVideo15Transformer3DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -634,6 +633,21 @@ class HunyuanVideo15Transformer3DModel(
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.config.patch_size_t, self.config.patch_size, self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
@@ -769,6 +783,10 @@ class HunyuanVideo15Transformer3DModel(
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, get_logger
|
||||
from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import get_1d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
@@ -198,7 +198,6 @@ class HunyuanVideoFramepackTransformer3DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -218,6 +217,21 @@ class HunyuanVideoFramepackTransformer3DModel(
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p, p_t = self.config.patch_size, self.config.patch_size_t
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
@@ -323,6 +337,10 @@ class HunyuanVideoFramepackTransformer3DModel(
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
return Transformer2DModelOutput(sample=hidden_states)
|
||||
|
||||
@@ -23,7 +23,7 @@ from diffusers.loaders import FromOriginalModelMixin
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -742,7 +742,6 @@ class HunyuanImageTransformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -756,6 +755,21 @@ class HunyuanImageTransformer2DModel(
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
if hidden_states.ndim == 4:
|
||||
batch_size, channels, height, width = hidden_states.shape
|
||||
sizes = (height, width)
|
||||
@@ -886,6 +900,10 @@ class HunyuanImageTransformer2DModel(
|
||||
]
|
||||
hidden_states = hidden_states.reshape(*final_dims)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, deprecate, is_torch_version, logging
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
@@ -491,7 +491,6 @@ class LTXVideoTransformer3DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -506,6 +505,21 @@ class LTXVideoTransformer3DModel(
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
@@ -554,6 +568,10 @@ class LTXVideoTransformer3DModel(
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -22,7 +22,14 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import BaseOutput, apply_lora_scale, is_torch_version, logging
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
BaseOutput,
|
||||
is_torch_version,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -1094,7 +1101,6 @@ class LTX2VideoTransformer3DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -1165,6 +1171,21 @@ class LTX2VideoTransformer3DModel(
|
||||
`tuple` is returned where the first element is the denoised video latent patch sequence and the second
|
||||
element is the denoised audio latent patch sequence.
|
||||
"""
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# Determine timestep for audio.
|
||||
audio_timestep = audio_timestep if audio_timestep is not None else timestep
|
||||
|
||||
@@ -1320,6 +1341,10 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
|
||||
audio_output = self.audio_proj_out(audio_hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output, audio_output)
|
||||
return AudioVisualModelOutput(sample=output, audio_sample=audio_output)
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import LuminaFeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
|
||||
@@ -455,7 +455,6 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -465,6 +464,21 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# 1. Condition, positional & patch embedding
|
||||
batch_size, _, height, width = hidden_states.shape
|
||||
|
||||
@@ -525,6 +539,10 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
|
||||
)
|
||||
output = torch.stack(output, dim=0)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn as nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
|
||||
@@ -404,7 +404,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -414,6 +413,21 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p = self.config.patch_size
|
||||
|
||||
@@ -465,6 +479,10 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
|
||||
hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
|
||||
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, deprecate, logging
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
@@ -829,7 +829,6 @@ class QwenImageTransformer2DModel(
|
||||
self.gradient_checkpointing = False
|
||||
self.zero_cond_t = zero_cond_t
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -888,6 +887,20 @@ class QwenImageTransformer2DModel(
|
||||
"The mask-based approach is more flexible and supports variable-length sequences.",
|
||||
standard_warn=False,
|
||||
)
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
|
||||
@@ -968,6 +981,10 @@ class QwenImageTransformer2DModel(
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
@@ -570,7 +570,6 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -583,6 +582,21 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
|
||||
controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
@@ -681,6 +695,10 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionMixin, FeedForward, JointTransformerBlock
|
||||
from ..attention_processor import (
|
||||
@@ -245,7 +245,6 @@ class SD3Transformer2DModel(
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -285,6 +284,20 @@ class SD3Transformer2DModel(
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
|
||||
@@ -339,6 +352,10 @@ class SD3Transformer2DModel(
|
||||
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
|
||||
)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, deprecate, logging
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -630,7 +630,6 @@ class SkyReelsV2Transformer3DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -642,6 +641,21 @@ class SkyReelsV2Transformer3DModel(
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
@@ -757,6 +771,10 @@ class SkyReelsV2Transformer3DModel(
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, deprecate, logging
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
@@ -622,7 +622,6 @@ class WanTransformer3DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -632,6 +631,21 @@ class WanTransformer3DModel(
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
@@ -699,6 +713,10 @@ class WanTransformer3DModel(
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
@@ -1141,7 +1141,6 @@ class WanAnimateTransformer3DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -1180,6 +1179,21 @@ class WanAnimateTransformer3DModel(
|
||||
Whether to return the output as a dict or tuple.
|
||||
"""
|
||||
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# Check that shapes match up
|
||||
if pose_hidden_states is not None and pose_hidden_states.shape[2] + 1 != hidden_states.shape[2]:
|
||||
raise ValueError(
|
||||
@@ -1280,6 +1294,10 @@ class WanAnimateTransformer3DModel(
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
@@ -261,7 +261,6 @@ class WanVACETransformer3DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -273,6 +272,21 @@ class WanVACETransformer3DModel(
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
@@ -365,6 +379,10 @@ class WanVACETransformer3DModel(
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -20,12 +20,7 @@ import torch.nn as nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
apply_lora_scale,
|
||||
deprecate,
|
||||
logging,
|
||||
)
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..activations import get_activation
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_processor import (
|
||||
@@ -979,7 +974,6 @@ class UNet2DConditionModel(
|
||||
encoder_hidden_states = (encoder_hidden_states, image_embeds)
|
||||
return encoder_hidden_states
|
||||
|
||||
@apply_lora_scale("cross_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
@@ -1118,6 +1112,18 @@ class UNet2DConditionModel(
|
||||
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
||||
|
||||
# 3. down
|
||||
# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
|
||||
# to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
|
||||
if cross_attention_kwargs is not None:
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy()
|
||||
lora_scale = cross_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
|
||||
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
||||
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
||||
is_adapter = down_intrablock_additional_residuals is not None
|
||||
@@ -1233,6 +1239,10 @@ class UNet2DConditionModel(
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||
from ...utils import BaseOutput, apply_lora_scale, deprecate, logging
|
||||
from ...utils import BaseOutput, deprecate, logging
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..attention import AttentionMixin, BasicTransformerBlock
|
||||
from ..attention_processor import (
|
||||
@@ -1875,7 +1875,6 @@ class UNetMotionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLo
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
@apply_lora_scale("cross_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
|
||||
@@ -21,7 +21,6 @@ from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale
|
||||
from ..attention import AttentionMixin, BasicTransformerBlock, SkipFFTransformerBlock
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
@@ -147,7 +146,6 @@ class UVit2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("cross_attention_kwargs")
|
||||
def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
|
||||
encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
|
||||
encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)
|
||||
|
||||
@@ -130,7 +130,6 @@ from .loading_utils import get_module_from_name, get_submodule_by_name, load_ima
|
||||
from .logging import get_logger
|
||||
from .outputs import BaseOutput
|
||||
from .peft_utils import (
|
||||
apply_lora_scale,
|
||||
check_peft_version,
|
||||
delete_adapter_layers,
|
||||
get_adapter_name,
|
||||
|
||||
@@ -16,7 +16,6 @@ PEFT utilities: Utilities related to peft library
|
||||
"""
|
||||
|
||||
import collections
|
||||
import functools
|
||||
import importlib
|
||||
from typing import Optional
|
||||
|
||||
@@ -276,59 +275,6 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
|
||||
module.set_scale(adapter_name, get_module_weight(weight, module_name))
|
||||
|
||||
|
||||
def apply_lora_scale(kwargs_name: str = "joint_attention_kwargs"):
|
||||
"""
|
||||
Decorator to automatically handle LoRA layer scaling/unscaling in forward methods.
|
||||
|
||||
This decorator extracts the `lora_scale` from the specified kwargs parameter, applies scaling before the forward
|
||||
pass, and ensures unscaling happens after, even if an exception occurs.
|
||||
|
||||
Args:
|
||||
kwargs_name (`str`, defaults to `"joint_attention_kwargs"`):
|
||||
The name of the keyword argument that contains the LoRA scale. Common values include
|
||||
"joint_attention_kwargs", "attention_kwargs", "cross_attention_kwargs", etc.
|
||||
"""
|
||||
|
||||
def decorator(forward_fn):
|
||||
@functools.wraps(forward_fn)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
from . import USE_PEFT_BACKEND
|
||||
|
||||
lora_scale = 1.0
|
||||
attention_kwargs = kwargs.get(kwargs_name)
|
||||
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
kwargs[kwargs_name] = attention_kwargs
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
if (
|
||||
not USE_PEFT_BACKEND
|
||||
and attention_kwargs is not None
|
||||
and attention_kwargs.get("scale", None) is not None
|
||||
):
|
||||
logger.warning(
|
||||
f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# Apply LoRA scaling if using PEFT backend
|
||||
if USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self, lora_scale)
|
||||
|
||||
try:
|
||||
# Execute the forward pass
|
||||
result = forward_fn(self, *args, **kwargs)
|
||||
return result
|
||||
finally:
|
||||
# Always unscale, even if forward pass raises an exception
|
||||
if USE_PEFT_BACKEND:
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def check_peft_version(min_version: str) -> None:
|
||||
r"""
|
||||
Checks if the version of PEFT is compatible.
|
||||
|
||||
@@ -219,6 +219,10 @@ class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
|
||||
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Flux Transformer."""
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"FluxTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Flux Transformer."""
|
||||
|
||||
@@ -13,48 +13,88 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import Flux2Transformer2DModel, attention_backend
|
||||
from diffusers import Flux2Transformer2DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
ContextParallelTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = Flux2Transformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
class Flux2TransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return Flux2Transformer2DModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
return (16, 4)
|
||||
|
||||
def prepare_dummy_input(self, height=4, width=4):
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
return [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def uses_custom_attn_processor(self) -> bool:
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
return True
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||
return {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
"num_single_layers": 1,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 2,
|
||||
"joint_attention_dim": 32,
|
||||
"timestep_guidance_channels": 256, # Hardcoded in original code
|
||||
"axes_dims_rope": [4, 4, 4, 4],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
@@ -82,81 +122,244 @@ class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
"num_single_layers": 1,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 2,
|
||||
"joint_attention_dim": 32,
|
||||
"timestep_guidance_channels": 256, # Hardcoded in original code
|
||||
"axes_dims_rope": [4, 4, 4, 4],
|
||||
}
|
||||
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
class TestFlux2Transformer(Flux2TransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
# TODO (Daniel, Sayak): We can remove this test.
|
||||
def test_flux2_consistency(self, seed=0):
|
||||
torch.manual_seed(seed)
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
torch.manual_seed(seed)
|
||||
model = self.model_class(**init_dict)
|
||||
# state_dict = model.state_dict()
|
||||
# for key, param in state_dict.items():
|
||||
# print(f"{key} | {param.shape}")
|
||||
# torch.save(state_dict, "/raid/daniel_gu/test_flux2_params/diffusers.pt")
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
class TestFlux2TransformerMemory(Flux2TransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Flux2 Transformer."""
|
||||
|
||||
with attention_backend("native"):
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
|
||||
# input & output have to have the same shape
|
||||
input_tensor = inputs_dict[self.main_input_name]
|
||||
expected_shape = input_tensor.shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
# Check against expected slice
|
||||
# fmt: off
|
||||
expected_slice = torch.tensor([-0.3662, 0.4844, 0.6334, -0.3497, 0.2162, 0.0188, 0.0521, -0.2061, -0.2041, -0.0342, -0.7107, 0.4797, -0.3280, 0.7059, -0.0849, 0.4416])
|
||||
# fmt: on
|
||||
|
||||
flat_output = output.cpu().flatten()
|
||||
generated_slice = torch.cat([flat_output[:8], flat_output[-8:]])
|
||||
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-4))
|
||||
class TestFlux2TransformerTraining(Flux2TransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Flux2 Transformer."""
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Flux2Transformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = Flux2Transformer2DModel
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
class TestFlux2TransformerAttention(Flux2TransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
||||
model_class = Flux2Transformer2DModel
|
||||
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
|
||||
class TestFlux2TransformerContextParallel(Flux2TransformerTesterConfig, ContextParallelTesterMixin):
|
||||
"""Context Parallel inference tests for Flux2 Transformer."""
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
class TestFlux2TransformerLoRA(Flux2TransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerLoRAHotSwap(Flux2TransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for Flux2 Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for LoRA hotswap tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class TestFlux2TransformerCompile(Flux2TransformerTesterConfig, TorchCompileTesterMixin):
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for compilation tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class TestFlux2TransformerBitsAndBytes(Flux2TransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerTorchAo(Flux2TransformerTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerGGUF(Flux2TransformerTesterConfig, GGUFTesterMixin):
|
||||
"""GGUF quantization tests for Flux2 Transformer."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real FLUX2 model dimensions.
|
||||
|
||||
Flux2 defaults: in_channels=128, joint_attention_dim=15360
|
||||
"""
|
||||
batch_size = 1
|
||||
height = 64
|
||||
width = 64
|
||||
sequence_length = 512
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
)
|
||||
|
||||
# Flux2 uses 4D image/text IDs (t, h, w, l)
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
|
||||
guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class TestFlux2TransformerGGUFCompile(Flux2TransformerTesterConfig, GGUFCompileTesterMixin):
|
||||
"""GGUF + compile tests for Flux2 Transformer."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real FLUX2 model dimensions.
|
||||
|
||||
Flux2 defaults: in_channels=128, joint_attention_dim=15360
|
||||
"""
|
||||
batch_size = 1
|
||||
height = 64
|
||||
width = 64
|
||||
sequence_length = 512
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
)
|
||||
|
||||
# Flux2 uses 4D image/text IDs (t, h, w, l)
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
|
||||
guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user