Compare commits


12 Commits

Author SHA1 Message Date
Sayak Paul
efd6d69044 Merge branch 'main' into apply-lora-scale-decorator 2026-01-29 19:22:55 +05:30
Sayak Paul
314cfddf3a [ci] uniform run times and wheels for pytorch cuda. (#13047)
* uniform run times and wheels for pytorch cuda.

* 12.9

* change to 24.04.

* change to 24.04.
2026-01-29 19:22:30 +05:30
Sayak Paul
9b3947cf58 Merge branch 'main' into apply-lora-scale-decorator 2026-01-29 19:01:15 +05:30
sayakpaul
e5ebacb820 fix 2026-01-28 12:31:24 +05:30
sayakpaul
8c402d3a32 remove. 2026-01-28 12:16:39 +05:30
sayakpaul
458ac949a0 remove more. 2026-01-28 12:14:21 +05:30
sayakpaul
290f749bd5 up 2026-01-28 12:10:51 +05:30
sayakpaul
d6fcd78d0e apply to the rest. 2026-01-28 11:53:16 +05:30
Sayak Paul
9afafe5e26 Merge branch 'main' into apply-lora-scale-decorator 2026-01-28 09:30:36 +05:30
Sayak Paul
3cdce4d2e8 Merge branch 'main' into apply-lora-scale-decorator 2026-01-27 20:21:54 +08:00
Sayak Paul
835a087a47 Merge branch 'main' into apply-lora-scale-decorator 2026-01-20 10:44:21 +05:30
sayakpaul
afa4a23c6c feat: implement apply_lora_scale to remove boilerplate. 2026-01-19 10:04:24 +05:30
44 changed files with 163 additions and 703 deletions
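The `apply_lora_scale` decorator itself is not part of this compare view. What follows is a rough sketch only, inferred from the boilerplate it replaces in each `forward()` below: `USE_PEFT_BACKEND`, `scale_lora_layers`, and `unscale_lora_layers` are the `diffusers.utils` names used by the removed code, while the wrapper logic here is an assumption, not the actual implementation.

# Hypothetical sketch, NOT the real diffusers.utils.apply_lora_scale.
# It mirrors the removed boilerplate: pop "scale" from the named kwargs
# dict, scale PEFT layers before forward(), unscale them afterwards.
import functools

USE_PEFT_BACKEND = True  # stand-in for the real diffusers.utils constant

def scale_lora_layers(model, lora_scale):  # stand-in for the real helper
    pass

def unscale_lora_layers(model, lora_scale):  # stand-in for the real helper
    pass

def apply_lora_scale(kwargs_name):
    def decorator(forward):
        @functools.wraps(forward)
        def wrapper(self, *args, **kwargs):
            # Assumes the kwargs dict (e.g. attention_kwargs) arrives by keyword.
            attn_kwargs = kwargs.get(kwargs_name)
            if attn_kwargs is not None:
                attn_kwargs = attn_kwargs.copy()
                lora_scale = attn_kwargs.pop("scale", 1.0)
                kwargs[kwargs_name] = attn_kwargs
            else:
                lora_scale = 1.0
            if USE_PEFT_BACKEND:
                # weight the lora layers by setting `lora_scale` for each PEFT layer
                scale_lora_layers(self, lora_scale)
            try:
                return forward(self, *args, **kwargs)
            finally:
                if USE_PEFT_BACKEND:
                    # remove `lora_scale` from each PEFT layer
                    unscale_lora_layers(self, lora_scale)
        return wrapper
    return decorator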

View File

@@ -1,4 +1,4 @@
-FROM nvidia/cuda:12.9.0-runtime-ubuntu20.04
+FROM nvidia/cuda:12.9.1-runtime-ubuntu24.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
@@ -36,7 +36,8 @@ ENV PATH="$VIRTUAL_ENV/bin:$PATH"
RUN uv pip install --no-cache-dir \
torch \
torchvision \
-torchaudio
+torchaudio \
+--index-url https://download.pytorch.org/whl/cu129
# Install compatible versions of numba/llvmlite for Python 3.10+
RUN uv pip install --no-cache-dir \

View File

@@ -1,4 +1,4 @@
-FROM nvidia/cuda:12.9.0-runtime-ubuntu20.04
+FROM nvidia/cuda:12.9.1-runtime-ubuntu24.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
@@ -36,7 +36,8 @@ ENV PATH="$VIRTUAL_ENV/bin:$PATH"
RUN uv pip install --no-cache-dir \
torch \
torchvision \
-torchaudio
+torchaudio \
+--index-url https://download.pytorch.org/whl/cu129
# Install compatible versions of numba/llvmlite for Python 3.10+
RUN uv pip install --no-cache-dir \

View File

@@ -106,6 +106,8 @@ video, audio = pipe(
output_type="np",
return_dict=False,
)
+video = (video * 255).round().astype("uint8")
+video = torch.from_numpy(video)
encode_video(
video[0],
@@ -183,6 +185,8 @@ video, audio = pipe(
output_type="np",
return_dict=False,
)
+video = (video * 255).round().astype("uint8")
+video = torch.from_numpy(video)
encode_video(
video[0],

View File

@@ -21,7 +21,7 @@ from torch.nn import functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
-from ...utils import BaseOutput, logging
+from ...utils import BaseOutput, apply_lora_scale, logging
from ..attention import AttentionMixin
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -598,6 +598,7 @@ class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModel
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
@apply_lora_scale("cross_attention_kwargs")
def forward(
self,
sample: torch.Tensor,

View File

@@ -20,7 +20,11 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    BaseOutput,
+    apply_lora_scale,
+    logging,
+)
from ..attention import AttentionMixin
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
@@ -150,6 +154,7 @@ class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
return controlnet
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -197,20 +202,6 @@ class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
-        if joint_attention_kwargs is not None:
-            joint_attention_kwargs = joint_attention_kwargs.copy()
-            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-                )
hidden_states = self.x_embedder(hidden_states)
if self.input_hint_block is not None:
@@ -323,10 +314,6 @@ class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (controlnet_block_samples, controlnet_single_block_samples)
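For a sense of the call-site behavior, here is a minimal self-contained demo of the decorator pattern, reusing the hypothetical sketch above (`TinyModel` is invented for illustration; callers keep passing "scale" exactly as before the refactor):

import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    @apply_lora_scale("joint_attention_kwargs")
    def forward(self, x, joint_attention_kwargs=None):
        # The decorator has already popped "scale" by the time we get here.
        assert joint_attention_kwargs is None or "scale" not in joint_attention_kwargs
        return self.proj(x)

model = TinyModel()
out = model(torch.randn(2, 4), joint_attention_kwargs={"scale": 0.8})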

View File

@@ -20,7 +20,12 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    BaseOutput,
+    apply_lora_scale,
+    deprecate,
+    logging,
+)
from ..attention import AttentionMixin
from ..cache_utils import CacheMixin
from ..controlnets.controlnet import zero_module
@@ -123,6 +128,7 @@ class QwenImageControlNetModel(
return controlnet
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -181,20 +187,6 @@ class QwenImageControlNetModel(
standard_warn=False,
)
-        if joint_attention_kwargs is not None:
-            joint_attention_kwargs = joint_attention_kwargs.copy()
-            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-                )
hidden_states = self.img_in(hidden_states)
# add
@@ -256,10 +248,6 @@ class QwenImageControlNetModel(
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return controlnet_block_samples

View File

@@ -20,7 +20,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import BaseOutput, apply_lora_scale, logging
from ..attention import AttentionMixin
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
@@ -117,6 +117,7 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -129,21 +130,6 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -218,10 +204,6 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
if not return_dict:

View File

@@ -21,7 +21,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from ..attention import AttentionMixin, JointTransformerBlock
from ..attention_processor import Attention, FusedJointAttnProcessor2_0
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
@@ -269,6 +269,7 @@ class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMix
return controlnet
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -308,21 +309,6 @@ class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMix
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
-        if joint_attention_kwargs is not None:
-            joint_attention_kwargs = joint_attention_kwargs.copy()
-            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-                )
if self.pos_embed is not None and hidden_states.ndim != 4:
raise ValueError("hidden_states must be 4D when pos_embed is used")
@@ -382,10 +368,6 @@ class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMix
# 6. scaling
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (controlnet_block_res_samples,)

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin
from ..attention_processor import (
@@ -397,6 +397,7 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.FloatTensor,
@@ -405,21 +406,6 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
height, width = hidden_states.shape[-2:]
# Apply patch embedding, timestep embedding, and project the caption embeddings.
@@ -486,10 +472,6 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -20,7 +20,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, AttentionMixin, FeedForward
from ..attention_processor import CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
@@ -363,6 +363,7 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -374,21 +375,6 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
batch_size, num_frames, channels, height, width = hidden_states.shape
# 1. Time embedding
@@ -454,10 +440,6 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -20,7 +20,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, AttentionMixin, FeedForward
from ..attention_processor import CogVideoXAttnProcessor2_0
@@ -620,6 +620,7 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
]
)
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -632,21 +633,6 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
id_vit_hidden: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
# fuse clip and insightface
valid_face_emb = None
if self.is_train_face:
@@ -720,10 +706,6 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -20,7 +20,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from ..attention import AttentionMixin
from ..attention_processor import (
Attention,
@@ -414,6 +414,7 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -426,21 +427,6 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -527,10 +513,6 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -8,7 +8,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -581,6 +581,7 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -621,20 +622,6 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype)
@@ -715,10 +702,6 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -22,10 +22,8 @@ from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_bria import BriaAttnProcessor
from ...utils import (
-    USE_PEFT_BACKEND,
+    apply_lora_scale,
    logging,
-    scale_lora_layers,
-    unscale_lora_layers,
)
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
@@ -510,6 +508,7 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
]
self.caption_projection = nn.ModuleList(caption_projection)
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -545,20 +544,7 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
-        if joint_attention_kwargs is not None:
-            joint_attention_kwargs = joint_attention_kwargs.copy()
-            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-                )
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype)
@@ -645,10 +631,6 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, deprecate, logging
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, FeedForward
@@ -473,6 +473,7 @@ class ChromaTransformer2DModel(
self.gradient_checkpointing = False
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -511,20 +512,6 @@ class ChromaTransformer2DModel(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
-        if joint_attention_kwargs is not None:
-            joint_attention_kwargs = joint_attention_kwargs.copy()
-            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-                )
hidden_states = self.x_embedder(hidden_states)
@@ -631,10 +618,6 @@ class ChromaTransformer2DModel(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -638,6 +638,7 @@ class ChronoEditTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -647,21 +648,6 @@ class ChronoEditTransformer3DModel(
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -729,10 +715,6 @@ class ChronoEditTransformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -20,7 +20,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
@@ -703,6 +703,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -718,21 +719,6 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
batch_size, num_channels, height, width = hidden_states.shape
# 1. RoPE
@@ -779,10 +765,6 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -22,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -634,6 +634,7 @@ class FluxTransformer2DModel(
self.gradient_checkpointing = False
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -675,20 +676,6 @@ class FluxTransformer2DModel(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
-        if joint_attention_kwargs is not None:
-            joint_attention_kwargs = joint_attention_kwargs.copy()
-            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-                )
hidden_states = self.x_embedder(hidden_states)
@@ -785,10 +772,6 @@ class FluxTransformer2DModel(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
@@ -774,6 +774,7 @@ class Flux2Transformer2DModel(
self.gradient_checkpointing = False
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -810,20 +811,6 @@ class Flux2Transformer2DModel(
`tuple` where the first element is the sample tensor.
"""
# 0. Handle input arguments
-        if joint_attention_kwargs is not None:
-            joint_attention_kwargs = joint_attention_kwargs.copy()
-            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-                )
num_txt_tokens = encoder_hidden_states.shape[1]
@@ -908,10 +895,6 @@ class Flux2Transformer2DModel(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -8,7 +8,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention
from ..embeddings import TimestepEmbedding, Timesteps
@@ -773,6 +773,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
return hidden_states, hidden_states_masks, img_sizes, img_ids
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -808,21 +809,6 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
"if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
)
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
# spatial forward
batch_size = hidden_states.shape[0]
hidden_states_type = hidden_states.dtype
@@ -933,10 +919,6 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
if hidden_states_masks is not None:
hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -22,7 +22,7 @@ from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
@@ -989,6 +989,7 @@ class HunyuanVideoTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -1000,21 +1001,6 @@ class HunyuanVideoTransformer3DModel(
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p, p_t = self.config.patch_size, self.config.patch_size_t
post_patch_num_frames = num_frames // p_t
@@ -1104,10 +1090,6 @@ class HunyuanVideoTransformer3DModel(
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)

View File

@@ -22,7 +22,7 @@ from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
@@ -620,6 +620,7 @@ class HunyuanVideo15Transformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -633,21 +634,6 @@ class HunyuanVideo15Transformer3DModel(
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size_t, self.config.patch_size, self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -783,10 +769,6 @@ class HunyuanVideo15Transformer3DModel(
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)

View File

@@ -20,7 +20,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, get_logger
from ..cache_utils import CacheMixin
from ..embeddings import get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
@@ -198,6 +198,7 @@ class HunyuanVideoFramepackTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -217,21 +218,6 @@ class HunyuanVideoFramepackTransformer3DModel(
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p, p_t = self.config.patch_size, self.config.patch_size_t
post_patch_num_frames = num_frames // p_t
@@ -337,10 +323,6 @@ class HunyuanVideoFramepackTransformer3DModel(
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)
return Transformer2DModelOutput(sample=hidden_states)

View File

@@ -23,7 +23,7 @@ from diffusers.loaders import FromOriginalModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -742,6 +742,7 @@ class HunyuanImageTransformer2DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -755,21 +756,6 @@ class HunyuanImageTransformer2DModel(
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
if hidden_states.ndim == 4:
batch_size, channels, height, width = hidden_states.shape
sizes = (height, width)
@@ -900,10 +886,6 @@ class HunyuanImageTransformer2DModel(
]
hidden_states = hidden_states.reshape(*final_dims)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)

View File

@@ -22,7 +22,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, deprecate, is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -491,6 +491,7 @@ class LTXVideoTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -505,21 +506,6 @@ class LTXVideoTransformer3DModel(
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> torch.Tensor:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
@@ -568,10 +554,6 @@ class LTXVideoTransformer3DModel(
hidden_states = hidden_states * (1 + scale) + shift
output = self.proj_out(hidden_states)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -22,14 +22,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import (
-    USE_PEFT_BACKEND,
-    BaseOutput,
-    is_torch_version,
-    logging,
-    scale_lora_layers,
-    unscale_lora_layers,
-)
+from ...utils import BaseOutput, apply_lora_scale, is_torch_version, logging
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -1101,6 +1094,7 @@ class LTX2VideoTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -1171,21 +1165,6 @@ class LTX2VideoTransformer3DModel(
`tuple` is returned where the first element is the denoised video latent patch sequence and the second
element is the denoised audio latent patch sequence.
"""
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
# Determine timestep for audio.
audio_timestep = audio_timestep if audio_timestep is not None else timestep
@@ -1341,10 +1320,6 @@ class LTX2VideoTransformer3DModel(
audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
audio_output = self.audio_proj_out(audio_hidden_states)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output, audio_output)
return AudioVisualModelOutput(sample=output, audio_sample=audio_output)

View File

@@ -22,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from ..attention import LuminaFeedForward
from ..attention_processor import Attention
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
@@ -455,6 +455,7 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -464,21 +465,6 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
# 1. Condition, positional & patch embedding
batch_size, _, height, width = hidden_states.shape
@@ -539,10 +525,6 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
)
output = torch.stack(output, dim=0)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -21,7 +21,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
@@ -404,6 +404,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -413,21 +414,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> torch.Tensor:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p = self.config.patch_size
@@ -479,10 +465,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -24,7 +24,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
@@ -829,6 +829,7 @@ class QwenImageTransformer2DModel(
self.gradient_checkpointing = False
self.zero_cond_t = zero_cond_t
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -887,20 +888,6 @@ class QwenImageTransformer2DModel(
"The mask-based approach is more flexible and supports variable-length sequences.",
standard_warn=False,
)
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-                )
hidden_states = self.img_in(hidden_states)
@@ -981,10 +968,6 @@ class QwenImageTransformer2DModel(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from ..attention import AttentionMixin
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
@@ -570,6 +570,7 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -582,21 +583,6 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -695,10 +681,6 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -18,7 +18,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, FeedForward, JointTransformerBlock
from ..attention_processor import (
@@ -245,6 +245,7 @@ class SD3Transformer2DModel(
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -284,20 +285,6 @@ class SD3Transformer2DModel(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
-        if joint_attention_kwargs is not None:
-            joint_attention_kwargs = joint_attention_kwargs.copy()
-            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-                )
height, width = hidden_states.shape[-2:]
@@ -352,10 +339,6 @@ class SD3Transformer2DModel(
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -630,6 +630,7 @@ class SkyReelsV2Transformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -641,21 +642,6 @@ class SkyReelsV2Transformer3DModel(
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
-        if attention_kwargs is not None:
-            attention_kwargs = attention_kwargs.copy()
-            lora_scale = attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                )
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -771,10 +757,6 @@ class SkyReelsV2Transformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -622,6 +622,7 @@ class WanTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -631,21 +632,6 @@ class WanTransformer3DModel(
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -713,10 +699,6 @@ class WanTransformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
@@ -1141,6 +1141,7 @@ class WanAnimateTransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -1179,21 +1180,6 @@ class WanAnimateTransformer3DModel(
Whether to return the output as a dict or tuple.
"""
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
# Check that shapes match up
if pose_hidden_states is not None and pose_hidden_states.shape[2] + 1 != hidden_states.shape[2]:
raise ValueError(
@@ -1294,10 +1280,6 @@ class WanAnimateTransformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -20,7 +20,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ..attention import AttentionMixin, FeedForward
from ..cache_utils import CacheMixin
from ..modeling_outputs import Transformer2DModelOutput
@@ -261,6 +261,7 @@ class WanVACETransformer3DModel(
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -272,21 +273,6 @@ class WanVACETransformer3DModel(
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -379,10 +365,6 @@ class WanVACETransformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -20,7 +20,12 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import (
BaseOutput,
apply_lora_scale,
deprecate,
logging,
)
from ..activations import get_activation
from ..attention import AttentionMixin
from ..attention_processor import (
@@ -974,6 +979,7 @@ class UNet2DConditionModel(
encoder_hidden_states = (encoder_hidden_states, image_embeds)
return encoder_hidden_states
@apply_lora_scale("cross_attention_kwargs")
def forward(
self,
sample: torch.Tensor,
@@ -1112,18 +1118,6 @@ class UNet2DConditionModel(
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
# 3. down
# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
# to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
if cross_attention_kwargs is not None:
cross_attention_kwargs = cross_attention_kwargs.copy()
lora_scale = cross_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
is_adapter = down_intrablock_additional_residuals is not None
@@ -1239,10 +1233,6 @@ class UNet2DConditionModel(
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (sample,)
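Because `UNet2DConditionModel.forward` is now decorated, nothing changes for callers: LoRA strength still flows through `cross_attention_kwargs`. A sketch of that path (the checkpoint ID and LoRA path are placeholders, and a loaded LoRA adapter is assumed):

import torch

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora")  # placeholder path

# `scale` is popped by @apply_lora_scale on the UNet's forward and applied
# to the PEFT LoRA layers only for the duration of each denoising step.
image = pipe(
    "an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.7},
).images[0]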

View File

@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput, deprecate, logging
from ...utils import BaseOutput, apply_lora_scale, deprecate, logging
from ...utils.torch_utils import apply_freeu
from ..attention import AttentionMixin, BasicTransformerBlock
from ..attention_processor import (
@@ -1875,6 +1875,7 @@ class UNetMotionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLo
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
@apply_lora_scale("cross_attention_kwargs")
def forward(
self,
sample: torch.Tensor,

View File

@@ -21,6 +21,7 @@ from torch.utils.checkpoint import checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import apply_lora_scale
from ..attention import AttentionMixin, BasicTransformerBlock, SkipFFTransformerBlock
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -146,6 +147,7 @@ class UVit2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin):
self.gradient_checkpointing = False
@apply_lora_scale("cross_attention_kwargs")
def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)

View File

@@ -13,14 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Generator, Iterator
from fractions import Fraction
from typing import List, Optional, Tuple, Union
from typing import Optional
import numpy as np
import PIL.Image
import torch
from tqdm import tqdm
from ...utils import is_av_available
@@ -105,52 +101,11 @@ def _write_audio(
def encode_video(
video: Union[List[PIL.Image.Image], np.ndarray, torch.Tensor, Iterator[torch.Tensor]],
fps: int,
audio: Optional[torch.Tensor],
audio_sample_rate: Optional[int],
output_path: str,
video_chunks_number: int = 1,
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
) -> None:
"""
Encodes a video with audio using the PyAV library. Based on code from the original LTX-2 repo:
https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L182
video_np = video.cpu().numpy()
Args:
video (`List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`):
A video tensor of shape [frames, height, width, channels] with integer pixel values in [0, 255]. If the
input is a `np.ndarray`, it is expected to be a float array with values in [0, 1] (which is what pipelines
usually return with `output_type="np"`).
fps (`int`)
The frames per second (FPS) of the encoded video.
audio (`torch.Tensor`, *optional*):
An audio waveform of shape [audio_channels, samples].
audio_sample_rate: (`int`, *optional*):
The sampling rate of the audio waveform. For LTX 2, this is typically 24000 (24 kHz).
output_path (`str`):
The path to save the encoded video to.
video_chunks_number (`int`, *optional*, defaults to `1`):
The number of chunks to split the video into for encoding. Each chunk will be encoded separately. The
number of chunks to use often depends on the tiling config for the video VAE.
"""
if isinstance(video, list) and isinstance(video[0], PIL.Image.Image):
# Pipeline output_type="pil"
video_frames = [np.array(frame) for frame in video]
video = np.stack(video_frames, axis=0)
video = torch.from_numpy(video)
elif isinstance(video, np.ndarray):
# Pipeline output_type="np"
is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video))
if np.all(is_denormalized):
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
if isinstance(video, torch.Tensor):
video = iter([video])
first_chunk = next(video)
_, height, width, _ = first_chunk.shape
_, height, width, _ = video_np.shape
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
@@ -164,18 +119,10 @@ def encode_video(
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
def all_tiles(
first_chunk: torch.Tensor, tiles_generator: Generator[Tuple[torch.Tensor, int], None, None]
) -> Generator[Tuple[torch.Tensor, int], None, None]:
yield first_chunk
yield from tiles_generator
for video_chunk in tqdm(all_tiles(first_chunk, video), total=video_chunks_number):
video_chunk_cpu = video_chunk.to("cpu").numpy()
for frame_array in video_chunk_cpu:
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)
for frame_array in video_np:
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():
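With the simplified signature, callers are expected to hand `encode_video` a uint8 tensor shaped [frames, height, width, channels], converting pipeline output themselves (the docstring updates below do exactly this). A sketch under that assumption, with random frames standing in for real pipeline output:

import numpy as np
import torch

# Pipelines with output_type="np" return float frames in [0, 1],
# shaped [batch, frames, height, width, channels].
video = np.random.rand(1, 8, 64, 64, 3).astype(np.float32)

video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)

# `encode_video` as defined above; no audio track in this sketch.
encode_video(video[0], fps=24, audio=None, audio_sample_rate=None, output_path="out.mp4")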

View File

@@ -69,6 +69,8 @@ EXAMPLE_DOC_STRING = """
... output_type="np",
... return_dict=False,
... )
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],

View File

@@ -75,6 +75,8 @@ EXAMPLE_DOC_STRING = """
... output_type="np",
... return_dict=False,
... )
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],

View File

@@ -76,6 +76,8 @@ EXAMPLE_DOC_STRING = """
... output_type="np",
... return_dict=False,
... )[0]
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],

View File

@@ -130,6 +130,7 @@ from .loading_utils import get_module_from_name, get_submodule_by_name, load_ima
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
apply_lora_scale,
check_peft_version,
delete_adapter_layers,
get_adapter_name,

View File

@@ -16,6 +16,7 @@ PEFT utilities: Utilities related to peft library
"""
import collections
import functools
import importlib
from typing import Optional
@@ -275,6 +276,59 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
module.set_scale(adapter_name, get_module_weight(weight, module_name))
def apply_lora_scale(kwargs_name: str = "joint_attention_kwargs"):
    """
    Decorator to automatically handle LoRA layer scaling/unscaling in forward methods.

    This decorator extracts the `lora_scale` from the specified kwargs parameter, applies scaling before the forward
    pass, and ensures unscaling happens after, even if an exception occurs.

    Args:
        kwargs_name (`str`, defaults to `"joint_attention_kwargs"`):
            The name of the keyword argument that carries the LoRA scale. Common values are
            `"joint_attention_kwargs"`, `"attention_kwargs"`, and `"cross_attention_kwargs"`.
    """

    def decorator(forward_fn):
        @functools.wraps(forward_fn)
        def wrapper(self, *args, **kwargs):
            from . import USE_PEFT_BACKEND

            lora_scale = 1.0
            attention_kwargs = kwargs.get(kwargs_name)
            if attention_kwargs is not None:
                # Copy so that popping `scale` does not mutate the caller's dict.
                attention_kwargs = attention_kwargs.copy()
                kwargs[kwargs_name] = attention_kwargs
                lora_scale = attention_kwargs.pop("scale", 1.0)

            if not USE_PEFT_BACKEND and lora_scale != 1.0:
                logger.warning(
                    f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
                )

            if USE_PEFT_BACKEND:
                # Weight the LoRA layers by setting `lora_scale` for each PEFT layer.
                scale_lora_layers(self, lora_scale)

            try:
                return forward_fn(self, *args, **kwargs)
            finally:
                if USE_PEFT_BACKEND:
                    # Always remove `lora_scale` from each PEFT layer, even if the forward pass raised.
                    unscale_lora_layers(self, lora_scale)

        return wrapper

    return decorator
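# Illustrative equivalence (hypothetical `forward`; not part of this module's API):
#
#     @apply_lora_scale("attention_kwargs")
#     def forward(self, hidden_states, attention_kwargs=None):
#         ...
#
# behaves like the per-model boilerplate this change removes:
#
#     kwargs_copy = attention_kwargs.copy() if attention_kwargs is not None else {}
#     lora_scale = kwargs_copy.pop("scale", 1.0)
#     if USE_PEFT_BACKEND:
#         scale_lora_layers(self, lora_scale)
#     try:
#         ...  # original forward body, receiving kwargs_copy
#     finally:
#         if USE_PEFT_BACKEND:
#             unscale_lora_layers(self, lora_scale)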
def check_peft_version(min_version: str) -> None:
r"""
Checks if the version of PEFT is compatible.