Compare commits

..

1 Commits

Author SHA1 Message Date
Sayak Paul
8abcf351c9 feat: implement apply_lora_scale to remove boilerplate. (#12994)
* feat: implement apply_lora_scale to remove boilerplate.

* apply to the rest.

* up

* remove more.

* remove.

* fix

* apply feedback.
2026-02-13 23:25:46 +05:30
39 changed files with 157 additions and 705 deletions

View File

@@ -21,7 +21,7 @@ from torch.nn import functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import BaseOutput, logging
from ...utils import BaseOutput, apply_lora_scale, logging
from ..attention import AttentionMixin
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -598,6 +598,7 @@ class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModel
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
@apply_lora_scale("cross_attention_kwargs")
def forward(
self,
sample: torch.Tensor,

View File

@@ -20,7 +20,11 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ...utils import (
BaseOutput,
apply_lora_scale,
logging,
)
from ..attention import AttentionMixin
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
@@ -150,6 +154,7 @@ class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
return controlnet
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -197,20 +202,6 @@ class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
if self.input_hint_block is not None:
@@ -323,10 +314,6 @@ class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (controlnet_block_samples, controlnet_single_block_samples)

View File

@@ -20,7 +20,12 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import (
BaseOutput,
apply_lora_scale,
deprecate,
logging,
)
from ..attention import AttentionMixin
from ..cache_utils import CacheMixin
from ..controlnets.controlnet import zero_module
@@ -123,6 +128,7 @@ class QwenImageControlNetModel(
return controlnet
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -181,20 +187,6 @@ class QwenImageControlNetModel(
standard_warn=False,
)
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.img_in(hidden_states)
# add
@@ -256,10 +248,6 @@ class QwenImageControlNetModel(
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return controlnet_block_samples

View File

@@ -20,7 +20,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ...utils import BaseOutput, apply_lora_scale, logging
from ..attention import AttentionMixin
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
@@ -117,6 +117,7 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -129,21 +130,6 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> tuple[torch.Tensor, ...] | Transformer2DModelOutput:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -218,10 +204,6 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
if not return_dict:

View File

@@ -21,7 +21,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ..attention import AttentionMixin, JointTransformerBlock
from ..attention_processor import Attention, FusedJointAttnProcessor2_0
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
@@ -269,6 +269,7 @@ class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMix
return controlnet
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -308,21 +309,6 @@ class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMix
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
if self.pos_embed is not None and hidden_states.ndim != 4:
raise ValueError("hidden_states must be 4D when pos_embed is used")
@@ -382,10 +368,6 @@ class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMix
# 6. scaling
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (controlnet_block_res_samples,)

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin
from ..attention_processor import (
@@ -397,6 +397,7 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.FloatTensor,
@@ -405,21 +406,6 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
height, width = hidden_states.shape[-2:]
# Apply patch embedding, timestep embedding, and project the caption embeddings.
@@ -486,10 +472,6 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -20,7 +20,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, AttentionMixin, FeedForward
from ..attention_processor import CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
@@ -363,6 +363,7 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -374,21 +375,6 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_frames, channels, height, width = hidden_states.shape
# 1. Time embedding
@@ -454,10 +440,6 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -20,7 +20,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, AttentionMixin, FeedForward
from ..attention_processor import CogVideoXAttnProcessor2_0
@@ -620,6 +620,7 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
]
)
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -632,21 +633,6 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
id_vit_hidden: torch.Tensor | None = None,
return_dict: bool = True,
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# fuse clip and insightface
valid_face_emb = None
if self.is_train_face:
@@ -720,10 +706,6 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -20,7 +20,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ..attention import AttentionMixin
from ..attention_processor import (
Attention,
@@ -414,6 +414,7 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -426,21 +427,6 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
controlnet_block_samples: tuple[torch.Tensor] | None = None,
return_dict: bool = True,
) -> tuple[torch.Tensor, ...] | Transformer2DModelOutput:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -527,10 +513,6 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -8,7 +8,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -581,6 +581,7 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -621,20 +622,6 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype)
@@ -715,10 +702,6 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -22,10 +22,8 @@ from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_bria import BriaAttnProcessor
from ...utils import (
USE_PEFT_BACKEND,
apply_lora_scale,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
@@ -510,6 +508,7 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
]
self.caption_projection = nn.ModuleList(caption_projection)
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -545,20 +544,7 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype)
@@ -645,10 +631,6 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, deprecate, logging
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, FeedForward
@@ -473,6 +473,7 @@ class ChromaTransformer2DModel(
self.gradient_checkpointing = False
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -511,20 +512,6 @@ class ChromaTransformer2DModel(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
@@ -631,10 +618,6 @@ class ChromaTransformer2DModel(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -641,6 +641,7 @@ class ChronoEditTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -650,21 +651,6 @@ class ChronoEditTransformer3DModel(
return_dict: bool = True,
attention_kwargs: dict[str, Any] | None = None,
) -> torch.Tensor | dict[str, torch.Tensor]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -732,10 +718,6 @@ class ChronoEditTransformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -20,7 +20,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
@@ -699,6 +699,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -712,21 +713,6 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
attention_mask: torch.Tensor | None = None,
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | list[tuple[torch.Tensor, torch.Tensor]] | None = None,
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, height, width = hidden_states.shape
# 1. RoPE
@@ -773,10 +759,6 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -22,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -634,6 +634,7 @@ class FluxTransformer2DModel(
self.gradient_checkpointing = False
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -675,20 +676,6 @@ class FluxTransformer2DModel(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
@@ -785,10 +772,6 @@ class FluxTransformer2DModel(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
@@ -774,6 +774,7 @@ class Flux2Transformer2DModel(
self.gradient_checkpointing = False
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -810,20 +811,6 @@ class Flux2Transformer2DModel(
`tuple` where the first element is the sample tensor.
"""
# 0. Handle input arguments
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
num_txt_tokens = encoder_hidden_states.shape[1]
@@ -908,10 +895,6 @@ class Flux2Transformer2DModel(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -8,7 +8,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention
from ..embeddings import TimestepEmbedding, Timesteps
@@ -773,6 +773,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
return hidden_states, hidden_states_masks, img_sizes, img_ids
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -808,21 +809,6 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
"if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
)
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# spatial forward
batch_size = hidden_states.shape[0]
hidden_states_type = hidden_states.dtype
@@ -933,10 +919,6 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
if hidden_states_masks is not None:
hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -22,7 +22,7 @@ from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
@@ -989,6 +989,7 @@ class HunyuanVideoTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -1000,21 +1001,6 @@ class HunyuanVideoTransformer3DModel(
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p, p_t = self.config.patch_size, self.config.patch_size_t
post_patch_num_frames = num_frames // p_t
@@ -1104,10 +1090,6 @@ class HunyuanVideoTransformer3DModel(
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)

View File

@@ -22,7 +22,7 @@ from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
@@ -620,6 +620,7 @@ class HunyuanVideo15Transformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -632,22 +633,7 @@ class HunyuanVideo15Transformer3DModel(
image_embeds: torch.Tensor | None = None,
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> tuple[torch.Tensor, Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size_t, self.config.patch_size, self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -783,10 +769,6 @@ class HunyuanVideo15Transformer3DModel(
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)

View File

@@ -20,7 +20,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, get_logger
from ..cache_utils import CacheMixin
from ..embeddings import get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
@@ -198,6 +198,7 @@ class HunyuanVideoFramepackTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -217,21 +218,6 @@ class HunyuanVideoFramepackTransformer3DModel(
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> tuple[torch.Tensor] | Transformer2DModelOutput:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p, p_t = self.config.patch_size, self.config.patch_size_t
post_patch_num_frames = num_frames // p_t
@@ -337,10 +323,6 @@ class HunyuanVideoFramepackTransformer3DModel(
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)
return Transformer2DModelOutput(sample=hidden_states)

View File

@@ -23,7 +23,7 @@ from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -740,6 +740,7 @@ class HunyuanImageTransformer2DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -753,21 +754,6 @@ class HunyuanImageTransformer2DModel(
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> torch.Tensor | dict[str, torch.Tensor]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
if hidden_states.ndim == 4:
batch_size, channels, height, width = hidden_states.shape
sizes = (height, width)
@@ -898,10 +884,6 @@ class HunyuanImageTransformer2DModel(
]
hidden_states = hidden_states.reshape(*final_dims)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)

View File

@@ -22,7 +22,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, deprecate, is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -491,6 +491,7 @@ class LTXVideoTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -505,21 +506,6 @@ class LTXVideoTransformer3DModel(
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> torch.Tensor:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
@@ -568,10 +554,6 @@ class LTXVideoTransformer3DModel(
hidden_states = hidden_states * (1 + scale) + shift
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -22,14 +22,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import (
USE_PEFT_BACKEND,
BaseOutput,
is_torch_version,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils import BaseOutput, apply_lora_scale, is_torch_version, logging
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -1099,6 +1092,7 @@ class LTX2VideoTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -1169,21 +1163,6 @@ class LTX2VideoTransformer3DModel(
`tuple` is returned where the first element is the denoised video latent patch sequence and the second
element is the denoised audio latent patch sequence.
"""
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# Determine timestep for audio.
audio_timestep = audio_timestep if audio_timestep is not None else timestep
@@ -1339,10 +1318,6 @@ class LTX2VideoTransformer3DModel(
audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
audio_output = self.audio_proj_out(audio_hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output, audio_output)
return AudioVisualModelOutput(sample=output, audio_sample=audio_output)

View File

@@ -22,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ..attention import LuminaFeedForward
from ..attention_processor import Attention
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
@@ -455,6 +455,7 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -464,21 +465,6 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> torch.Tensor | Transformer2DModelOutput:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# 1. Condition, positional & patch embedding
batch_size, _, height, width = hidden_states.shape
@@ -539,10 +525,6 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
)
output = torch.stack(output, dim=0)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -21,7 +21,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
@@ -404,6 +404,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -413,21 +414,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> torch.Tensor:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p = self.config.patch_size
@@ -479,10 +465,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -24,7 +24,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
@@ -829,6 +829,7 @@ class QwenImageTransformer2DModel(
self.gradient_checkpointing = False
self.zero_cond_t = zero_cond_t
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -887,20 +888,6 @@ class QwenImageTransformer2DModel(
"The mask-based approach is more flexible and supports variable-length sequences.",
standard_warn=False,
)
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.img_in(hidden_states)
@@ -981,10 +968,6 @@ class QwenImageTransformer2DModel(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ..attention import AttentionMixin
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
@@ -570,6 +570,7 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -581,22 +582,7 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
attention_kwargs: dict[str, Any] | None = None,
controlnet_block_samples: tuple[torch.Tensor] | None = None,
return_dict: bool = True,
) -> tuple[torch.Tensor, ..., Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
) -> tuple[torch.Tensor, ...] | Transformer2DModelOutput:
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -695,10 +681,6 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -18,7 +18,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, FeedForward, JointTransformerBlock
from ..attention_processor import (
@@ -245,6 +245,7 @@ class SD3Transformer2DModel(
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -284,20 +285,6 @@ class SD3Transformer2DModel(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
height, width = hidden_states.shape[-2:]
@@ -352,10 +339,6 @@ class SD3Transformer2DModel(
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -630,6 +630,7 @@ class SkyReelsV2Transformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -641,21 +642,6 @@ class SkyReelsV2Transformer3DModel(
return_dict: bool = True,
attention_kwargs: dict[str, Any] | None = None,
) -> torch.Tensor | dict[str, torch.Tensor]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -771,10 +757,6 @@ class SkyReelsV2Transformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -625,6 +625,7 @@ class WanTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -634,21 +635,6 @@ class WanTransformer3DModel(
return_dict: bool = True,
attention_kwargs: dict[str, Any] | None = None,
) -> torch.Tensor | dict[str, torch.Tensor]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -716,10 +702,6 @@ class WanTransformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
@@ -1147,6 +1147,7 @@ class WanAnimateTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -1185,21 +1186,6 @@ class WanAnimateTransformer3DModel(
Whether to return the output as a dict or tuple.
"""
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# Check that shapes match up
if pose_hidden_states is not None and pose_hidden_states.shape[2] + 1 != hidden_states.shape[2]:
raise ValueError(
@@ -1300,10 +1286,6 @@ class WanAnimateTransformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -20,7 +20,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ..attention import AttentionMixin, FeedForward
from ..cache_utils import CacheMixin
from ..modeling_outputs import Transformer2DModelOutput
@@ -263,6 +263,7 @@ class WanVACETransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -274,21 +275,6 @@ class WanVACETransformer3DModel(
return_dict: bool = True,
attention_kwargs: dict[str, Any] | None = None,
) -> torch.Tensor | dict[str, torch.Tensor]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -381,10 +367,6 @@ class WanVACETransformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -20,7 +20,12 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import (
BaseOutput,
apply_lora_scale,
deprecate,
logging,
)
from ..activations import get_activation
from ..attention import AttentionMixin
from ..attention_processor import (
@@ -972,6 +977,7 @@ class UNet2DConditionModel(
encoder_hidden_states = (encoder_hidden_states, image_embeds)
return encoder_hidden_states
@apply_lora_scale("cross_attention_kwargs")
def forward(
self,
sample: torch.Tensor,
@@ -1110,18 +1116,6 @@ class UNet2DConditionModel(
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
# 3. down
# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
# to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
if cross_attention_kwargs is not None:
cross_attention_kwargs = cross_attention_kwargs.copy()
lora_scale = cross_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
is_adapter = down_intrablock_additional_residuals is not None
@@ -1237,10 +1231,6 @@ class UNet2DConditionModel(
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (sample,)

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput, deprecate, logging
from ...utils import BaseOutput, apply_lora_scale, deprecate, logging
from ...utils.torch_utils import apply_freeu
from ..attention import AttentionMixin, BasicTransformerBlock
from ..attention_processor import (
@@ -1875,6 +1875,7 @@ class UNetMotionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLo
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
@apply_lora_scale("cross_attention_kwargs")
def forward(
self,
sample: torch.Tensor,

View File

@@ -21,6 +21,7 @@ from torch.utils.checkpoint import checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import apply_lora_scale
from ..attention import AttentionMixin, BasicTransformerBlock, SkipFFTransformerBlock
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -146,6 +147,7 @@ class UVit2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin):
self.gradient_checkpointing = False
@apply_lora_scale("cross_attention_kwargs")
def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)

View File

@@ -131,6 +131,7 @@ from .loading_utils import get_module_from_name, get_submodule_by_name, load_ima
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
apply_lora_scale,
check_peft_version,
delete_adapter_layers,
get_adapter_name,

View File

@@ -16,6 +16,7 @@ PEFT utilities: Utilities related to peft library
"""
import collections
import functools
import importlib
from packaging import version
@@ -274,6 +275,55 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
module.set_scale(adapter_name, get_module_weight(weight, module_name))
def apply_lora_scale(kwargs_name: str = "joint_attention_kwargs"):
"""
Decorator to automatically handle LoRA layer scaling/unscaling in forward methods.
This decorator extracts the `lora_scale` from the specified kwargs parameter, applies scaling before the forward
pass, and ensures unscaling happens after, even if an exception occurs.
Args:
kwargs_name (`str`, defaults to `"joint_attention_kwargs"`):
The name of the keyword argument that contains the LoRA scale. Common values include
"joint_attention_kwargs", "attention_kwargs", "cross_attention_kwargs", etc.
"""
def decorator(forward_fn):
@functools.wraps(forward_fn)
def wrapper(self, *args, **kwargs):
from . import USE_PEFT_BACKEND
lora_scale = 1.0
attention_kwargs = kwargs.get(kwargs_name)
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
kwargs[kwargs_name] = attention_kwargs
lora_scale = attention_kwargs.pop("scale", 1.0)
if not USE_PEFT_BACKEND and lora_scale != 1.0:
logger.warning(
f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
)
# Apply LoRA scaling if using PEFT backend
if USE_PEFT_BACKEND:
scale_lora_layers(self, lora_scale)
try:
# Execute the forward pass
result = forward_fn(self, *args, **kwargs)
return result
finally:
# Always unscale, even if forward pass raises an exception
if USE_PEFT_BACKEND:
unscale_lora_layers(self, lora_scale)
return wrapper
return decorator
def check_peft_version(min_version: str) -> None:
r"""
Checks if the version of PEFT is compatible.

View File

@@ -21,8 +21,11 @@ import torch
from diffusers import BitsAndBytesConfig, GGUFQuantizationConfig, NVIDIAModelOptConfig, QuantoConfig, TorchAoConfig
from diffusers.utils.import_utils import (
is_bitsandbytes_available,
is_gguf_available,
is_nvidia_modelopt_available,
is_optimum_quanto_available,
is_torchao_available,
is_torchao_version,
)
from ...testing_utils import (
@@ -56,6 +59,13 @@ if is_bitsandbytes_available():
if is_optimum_quanto_available():
from optimum.quanto import QLinear
if is_gguf_available():
pass
if is_torchao_available():
if is_torchao_version(">=", "0.9.0"):
pass
class LoRALayer(torch.nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only.
@@ -122,14 +132,14 @@ class QuantizationTesterMixin:
def _verify_if_layer_quantized(self, name, module, config_kwargs):
raise NotImplementedError("Subclass must implement _verify_if_layer_quantized")
def _is_module_quantized(self, module, quant_config_kwargs=None):
def _is_module_quantized(self, module):
"""
Check if a module is quantized. Returns True if quantized, False otherwise.
Default implementation tries _verify_if_layer_quantized and catches exceptions.
Subclasses can override for more efficient checking.
"""
try:
self._verify_if_layer_quantized("", module, quant_config_kwargs or {})
self._verify_if_layer_quantized("", module, {})
return True
except (AssertionError, AttributeError):
return False
@@ -263,9 +273,7 @@ class QuantizationTesterMixin:
f"Quantized layer count mismatch: expected {expected_quantized_layers}, got {num_quantized_layers} (total linear layers: {num_linear_layers}, FP32 modules: {num_fp32_modules})"
)
def _test_quantization_modules_to_not_convert(
self, config_kwargs, modules_to_not_convert, to_not_convert_key="modules_to_not_convert"
):
def _test_quantization_modules_to_not_convert(self, config_kwargs, modules_to_not_convert):
"""
Test that modules specified in modules_to_not_convert are not quantized.
@@ -275,7 +283,7 @@ class QuantizationTesterMixin:
"""
# Create config with modules_to_not_convert
config_kwargs_with_exclusion = config_kwargs.copy()
config_kwargs_with_exclusion[to_not_convert_key] = modules_to_not_convert
config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_not_convert
model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
@@ -287,7 +295,7 @@ class QuantizationTesterMixin:
if any(excluded in name for excluded in modules_to_not_convert):
found_excluded = True
# This module should NOT be quantized
assert not self._is_module_quantized(module, config_kwargs_with_exclusion), (
assert not self._is_module_quantized(module), (
f"Module {name} should not be quantized but was found to be quantized"
)
@@ -299,7 +307,7 @@ class QuantizationTesterMixin:
if isinstance(module, torch.nn.Linear):
# Check if this module is NOT in the exclusion list
if not any(excluded in name for excluded in modules_to_not_convert):
if self._is_module_quantized(module, config_kwargs_with_exclusion):
if self._is_module_quantized(module):
found_quantized = True
break
@@ -604,7 +612,7 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
pytest.skip("modules_to_not_convert_for_test not defined for this model")
self._test_quantization_modules_to_not_convert(
BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude, "llm_int8_skip_modules"
BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"], modules_to_exclude
)
@pytest.mark.parametrize("config_name", ["4bit_nf4", "8bit"], ids=["4bit_nf4", "8bit"])
@@ -803,14 +811,7 @@ class TorchAoConfigMixin:
return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
def _verify_if_layer_quantized(self, name, module, config_kwargs):
from torchao.dtypes import AffineQuantizedTensor
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
# Check if the weight is actually quantized
weight = module.weight
is_quantized = isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))
assert is_quantized, f"Layer {name} weight is not quantized, got {type(weight)}"
# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
@@ -906,39 +907,9 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
if modules_to_exclude is None:
pytest.skip("modules_to_not_convert_for_test not defined for this model")
# Custom implementation for torchao that skips memory footprint check
# because get_memory_footprint() doesn't accurately reflect torchao quantization
config_kwargs = TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"]
config_kwargs_with_exclusion = config_kwargs.copy()
config_kwargs_with_exclusion["modules_to_not_convert"] = modules_to_exclude
model_with_exclusion = self._create_quantized_model(config_kwargs_with_exclusion)
# Find a module that should NOT be quantized
found_excluded = False
for name, module in model_with_exclusion.named_modules():
if isinstance(module, torch.nn.Linear):
# Check if this module is in the exclusion list
if any(excluded in name for excluded in modules_to_exclude):
found_excluded = True
# This module should NOT be quantized
assert not self._is_module_quantized(module, config_kwargs_with_exclusion), (
f"Module {name} should not be quantized but was found to be quantized"
)
assert found_excluded, f"No linear layers found in excluded modules: {modules_to_exclude}"
# Find a module that SHOULD be quantized (not in exclusion list)
found_quantized = False
for name, module in model_with_exclusion.named_modules():
if isinstance(module, torch.nn.Linear):
# Check if this module is NOT in the exclusion list
if not any(excluded in name for excluded in modules_to_exclude):
if self._is_module_quantized(module, config_kwargs_with_exclusion):
found_quantized = True
break
assert found_quantized, "No quantized layers found outside of excluded modules"
self._test_quantization_modules_to_not_convert(
TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"], modules_to_exclude
)
def test_torchao_device_map(self):
"""Test that device_map='auto' works correctly with quantization."""

View File

@@ -318,10 +318,6 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Flux Transformer."""
@property
def modules_to_not_convert_for_test(self):
return ["norm_out.linear"]
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
"""Quanto quantization tests for Flux Transformer."""
@@ -334,18 +330,10 @@ class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
def pretrained_model_kwargs(self):
return {}
@property
def modules_to_not_convert_for_test(self):
return ["norm_out.linear"]
class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Flux Transformer."""
@property
def modules_to_not_convert_for_test(self):
return ["norm_out.linear"]
class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
@property
@@ -414,10 +402,6 @@ class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTes
class TestFluxTransformerModelOpt(FluxTransformerTesterConfig, ModelOptTesterMixin):
"""ModelOpt quantization tests for Flux Transformer."""
@property
def modules_to_not_convert_for_test(self):
return ["norm_out.linear"]
class TestFluxTransformerModelOptCompile(FluxTransformerTesterConfig, ModelOptCompileTesterMixin):
"""ModelOpt + compile tests for Flux Transformer."""