Mirror of https://github.com/huggingface/diffusers.git (synced 2026-02-06 19:05:11 +08:00)

Compare commits (2 commits): apply-lora ... sayakpaul-

| Author | SHA1 | Date |
|---|---|---|
|  | 8da128067c |  |
|  | 1b8fc6c589 |  |
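The hunks below all touch the same thing: how each diffusers model's `forward` handles the LoRA `scale` entry passed in its per-call attention kwargs. One side of the diff decorates `forward` with `@apply_lora_scale("<kwargs name>")`; the other carries the inline `USE_PEFT_BACKEND` / `scale_lora_layers` / `unscale_lora_layers` bookkeeping that such a decorator would factor out. The decorator's implementation is not part of this compare, so the following is only a minimal sketch, assuming it wraps `forward` around the same logic that appears inline in the hunks; the name, signature, and structure are assumptions, not the actual diffusers helper.

```python
# Hypothetical sketch of an `apply_lora_scale`-style decorator (not the real diffusers
# implementation). It assumes the target kwargs dict is passed by keyword and mirrors
# the inline scale/unscale bookkeeping shown in the hunks below.
import functools

from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers

logger = logging.get_logger(__name__)


def apply_lora_scale(kwargs_name: str):
    def decorator(forward_fn):
        @functools.wraps(forward_fn)
        def wrapper(self, *args, **kwargs):
            attn_kwargs = kwargs.get(kwargs_name)
            lora_scale = 1.0
            if attn_kwargs is not None:
                # copy so the caller's dict is not mutated, then pop the LoRA scale
                attn_kwargs = dict(attn_kwargs)
                lora_scale = attn_kwargs.pop("scale", 1.0)
                kwargs[kwargs_name] = attn_kwargs

            if USE_PEFT_BACKEND:
                # weight the LoRA layers by setting `lora_scale` for each PEFT layer
                scale_lora_layers(self, lora_scale)
            elif lora_scale != 1.0:
                logger.warning(
                    f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
                )

            try:
                return forward_fn(self, *args, **kwargs)
            finally:
                if USE_PEFT_BACKEND:
                    # remove `lora_scale` from each PEFT layer
                    unscale_lora_layers(self, lora_scale)

        return wrapper

    return decorator
```

With a helper of this kind, each model's `forward` carries a single `@apply_lora_scale("joint_attention_kwargs")` (or `"attention_kwargs"` / `"cross_attention_kwargs"`) line instead of the roughly twenty lines of boilerplate repeated in every hunk below.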
@@ -66,7 +66,7 @@ from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConf
 from torchao.quantization import Int4WeightOnlyConfig

 pipeline_quant_config = PipelineQuantizationConfig(
-    quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128)))}
+    quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128))}
 )
 pipeline = DiffusionPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
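The snippet above is cut off right after the repository id. For context, a completed version of the corrected example might look like the sketch below; passing `quantization_config=` to `DiffusionPipeline.from_pretrained` is how a `PipelineQuantizationConfig` is normally applied, while the dtype, device, and sampling settings here are illustrative assumptions rather than part of the original docs.

```python
# Hedged completion of the corrected snippet above; torch_dtype, .to("cuda"), and the
# prompt/num_inference_steps values are assumptions, not taken from the original docs.
import torch
from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig

pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128))}
)

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipeline("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
```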
@@ -21,7 +21,7 @@ from torch.nn import functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...loaders.single_file_model import FromOriginalModelMixin
-from ...utils import BaseOutput, apply_lora_scale, logging
+from ...utils import BaseOutput, logging
 from ..attention import AttentionMixin
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -598,7 +598,6 @@ class ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, FromOriginalModel
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)

-    @apply_lora_scale("cross_attention_kwargs")
     def forward(
         self,
         sample: torch.Tensor,
@@ -20,11 +20,7 @@ import torch.nn as nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...utils import (
-    BaseOutput,
-    apply_lora_scale,
-    logging,
-)
+from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import AttentionMixin
 from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
@@ -154,7 +150,6 @@ class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi

         return controlnet

-    @apply_lora_scale("joint_attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -202,6 +197,20 @@ class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
             `tuple` where the first element is the sample tensor.
         """
+        if joint_attention_kwargs is not None:
+            joint_attention_kwargs = joint_attention_kwargs.copy()
+            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                )
         hidden_states = self.x_embedder(hidden_states)

         if self.input_hint_block is not None:
@@ -314,6 +323,10 @@ class FluxControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
             None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
         )

+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
             return (controlnet_block_samples, controlnet_single_block_samples)

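Each model file below repeats the same handling: the incoming kwargs dict is copied so the caller's dict is never mutated, `scale` is popped out as `lora_scale`, and (under the PEFT backend) the model's LoRA layers are scaled before the forward pass and unscaled after it. A minimal, self-contained illustration of just the copy-and-pop step:

```python
# Toy illustration of the kwargs handling in the hunks above and below (not diffusers
# code): the dict is copied before popping, so the caller can reuse it across steps.
def extract_lora_scale(attention_kwargs):
    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    return lora_scale, attention_kwargs


caller_kwargs = {"scale": 0.75}
lora_scale, forwarded = extract_lora_scale(caller_kwargs)
assert lora_scale == 0.75
assert forwarded == {}                   # "scale" was popped from the copy ...
assert caller_kwargs == {"scale": 0.75}  # ... while the caller's dict is unchanged
```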
@@ -20,12 +20,7 @@ import torch.nn as nn

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import (
-    BaseOutput,
-    apply_lora_scale,
-    deprecate,
-    logging,
-)
+from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import AttentionMixin
 from ..cache_utils import CacheMixin
 from ..controlnets.controlnet import zero_module
@@ -128,7 +123,6 @@ class QwenImageControlNetModel(

         return controlnet

-    @apply_lora_scale("joint_attention_kwargs")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -187,6 +181,20 @@ class QwenImageControlNetModel(
                 standard_warn=False,
             )

+        if joint_attention_kwargs is not None:
+            joint_attention_kwargs = joint_attention_kwargs.copy()
+            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                )
         hidden_states = self.img_in(hidden_states)

         # add
@@ -248,6 +256,10 @@ class QwenImageControlNetModel(
         controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
         controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples

+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
             return controlnet_block_samples

||||
@@ -20,7 +20,7 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import BaseOutput, apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin
|
||||
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
@@ -117,7 +117,6 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -130,6 +129,21 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
@@ -204,6 +218,10 @@ class SanaControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMi
|
||||
block_res_sample = controlnet_block(block_res_sample)
|
||||
controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin, JointTransformerBlock
|
||||
from ..attention_processor import Attention, FusedJointAttnProcessor2_0
|
||||
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
|
||||
@@ -269,7 +269,6 @@ class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMix
|
||||
|
||||
return controlnet
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -309,6 +308,21 @@ class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMix
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
if self.pos_embed is not None and hidden_states.ndim != 4:
|
||||
raise ValueError("hidden_states must be 4D when pos_embed is used")
|
||||
|
||||
@@ -368,6 +382,10 @@ class SD3ControlNetModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMix
|
||||
# 6. scaling
|
||||
controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (controlnet_block_res_samples,)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_processor import (
|
||||
@@ -397,7 +397,6 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
@@ -406,6 +405,21 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
|
||||
# Apply patch embedding, timestep embedding, and project the caption embeddings.
|
||||
@@ -472,6 +486,10 @@ class AuraFlowTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
|
||||
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
|
||||
)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, AttentionMixin, FeedForward
|
||||
from ..attention_processor import CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
|
||||
@@ -363,7 +363,6 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -375,6 +374,21 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_frames, channels, height, width = hidden_states.shape
|
||||
|
||||
# 1. Time embedding
|
||||
@@ -440,6 +454,10 @@ class CogVideoXTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftA
|
||||
)
|
||||
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -20,7 +20,7 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention, AttentionMixin, FeedForward
|
||||
from ..attention_processor import CogVideoXAttnProcessor2_0
|
||||
@@ -620,7 +620,6 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
|
||||
]
|
||||
)
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -633,6 +632,21 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
|
||||
id_vit_hidden: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# fuse clip and insightface
|
||||
valid_face_emb = None
|
||||
if self.is_train_face:
|
||||
@@ -706,6 +720,10 @@ class ConsisIDTransformer3DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAd
|
||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
|
||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -20,7 +20,7 @@ from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_processor import (
|
||||
Attention,
|
||||
@@ -414,7 +414,6 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -427,6 +426,21 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
|
||||
controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
@@ -513,6 +527,10 @@ class SanaTransformer2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapte
|
||||
hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
|
||||
output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -581,7 +581,6 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -622,6 +621,20 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
@@ -702,6 +715,10 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -22,8 +22,10 @@ from ...models.modeling_outputs import Transformer2DModelOutput
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.transformers.transformer_bria import BriaAttnProcessor
|
||||
from ...utils import (
|
||||
apply_lora_scale,
|
||||
USE_PEFT_BACKEND,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin, FeedForward
|
||||
@@ -508,7 +510,6 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
]
|
||||
self.caption_projection = nn.ModuleList(caption_projection)
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -544,7 +545,20 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
@@ -631,6 +645,10 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, deprecate, logging
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.import_utils import is_torch_npu_available
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
@@ -473,7 +473,6 @@ class ChromaTransformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -512,6 +511,20 @@ class ChromaTransformer2DModel(
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
@@ -618,6 +631,10 @@ class ChromaTransformer2DModel(
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, deprecate, logging
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
@@ -638,7 +638,6 @@ class ChronoEditTransformer3DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -648,6 +647,21 @@ class ChronoEditTransformer3DModel(
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
@@ -715,6 +729,10 @@ class ChronoEditTransformer3DModel(
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
@@ -703,7 +703,6 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -719,6 +718,21 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
|
||||
] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, height, width = hidden_states.shape
|
||||
|
||||
# 1. RoPE
|
||||
@@ -765,6 +779,10 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
|
||||
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
@@ -634,7 +634,6 @@ class FluxTransformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -676,6 +675,20 @@ class FluxTransformer2DModel(
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
@@ -772,6 +785,10 @@ class FluxTransformer2DModel(
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -774,7 +774,6 @@ class Flux2Transformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -811,6 +810,20 @@ class Flux2Transformer2DModel(
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 0. Handle input arguments
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
num_txt_tokens = encoder_hidden_states.shape[1]
|
||||
|
||||
@@ -895,6 +908,10 @@ class Flux2Transformer2DModel(
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.modeling_outputs import Transformer2DModelOutput
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import apply_lora_scale, deprecate, logging
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import Attention
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
@@ -773,7 +773,6 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
|
||||
return hidden_states, hidden_states_masks, img_sizes, img_ids
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -809,6 +808,21 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
"if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
|
||||
)
|
||||
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# spatial forward
|
||||
batch_size = hidden_states.shape[0]
|
||||
hidden_states_type = hidden_states.dtype
|
||||
@@ -919,6 +933,10 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
|
||||
if hidden_states_masks is not None:
|
||||
hidden_states_masks = hidden_states_masks[:, :image_tokens_seq_len]
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -22,7 +22,7 @@ from diffusers.loaders import FromOriginalModelMixin
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
@@ -989,7 +989,6 @@ class HunyuanVideoTransformer3DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -1001,6 +1000,21 @@ class HunyuanVideoTransformer3DModel(
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p, p_t = self.config.patch_size, self.config.patch_size_t
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
@@ -1090,6 +1104,10 @@ class HunyuanVideoTransformer3DModel(
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from diffusers.loaders import FromOriginalModelMixin
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
@@ -620,7 +620,6 @@ class HunyuanVideo15Transformer3DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -634,6 +633,21 @@ class HunyuanVideo15Transformer3DModel(
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p_t, p_h, p_w = self.config.patch_size_t, self.config.patch_size, self.config.patch_size
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
@@ -769,6 +783,10 @@ class HunyuanVideo15Transformer3DModel(
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, get_logger
|
||||
from ...utils import USE_PEFT_BACKEND, get_logger, scale_lora_layers, unscale_lora_layers
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import get_1d_rotary_pos_embed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
@@ -198,7 +198,6 @@ class HunyuanVideoFramepackTransformer3DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -218,6 +217,21 @@ class HunyuanVideoFramepackTransformer3DModel(
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p, p_t = self.config.patch_size, self.config.patch_size_t
|
||||
post_patch_num_frames = num_frames // p_t
|
||||
@@ -323,6 +337,10 @@ class HunyuanVideoFramepackTransformer3DModel(
|
||||
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
return Transformer2DModelOutput(sample=hidden_states)
|
||||
|
||||
@@ -23,7 +23,7 @@ from diffusers.loaders import FromOriginalModelMixin
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -742,7 +742,6 @@ class HunyuanImageTransformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -756,6 +755,21 @@ class HunyuanImageTransformer2DModel(
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
if hidden_states.ndim == 4:
|
||||
batch_size, channels, height, width = hidden_states.shape
|
||||
sizes = (height, width)
|
||||
@@ -886,6 +900,10 @@ class HunyuanImageTransformer2DModel(
|
||||
]
|
||||
hidden_states = hidden_states.reshape(*final_dims)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states,)
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, deprecate, is_torch_version, logging
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
@@ -491,7 +491,6 @@ class LTXVideoTransformer3DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -506,6 +505,21 @@ class LTXVideoTransformer3DModel(
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale, video_coords)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
@@ -554,6 +568,10 @@ class LTXVideoTransformer3DModel(
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -22,7 +22,14 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import BaseOutput, apply_lora_scale, is_torch_version, logging
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
BaseOutput,
|
||||
is_torch_version,
|
||||
logging,
|
||||
scale_lora_layers,
|
||||
unscale_lora_layers,
|
||||
)
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -1094,7 +1101,6 @@ class LTX2VideoTransformer3DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -1165,6 +1171,21 @@ class LTX2VideoTransformer3DModel(
|
||||
`tuple` is returned where the first element is the denoised video latent patch sequence and the second
|
||||
element is the denoised audio latent patch sequence.
|
||||
"""
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# Determine timestep for audio.
|
||||
audio_timestep = audio_timestep if audio_timestep is not None else timestep
|
||||
|
||||
@@ -1320,6 +1341,10 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
|
||||
audio_output = self.audio_proj_out(audio_hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output, audio_output)
|
||||
return AudioVisualModelOutput(sample=output, audio_sample=audio_output)
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import LuminaFeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
|
||||
@@ -455,7 +455,6 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -465,6 +464,21 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.Tensor, Transformer2DModelOutput]:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
# 1. Condition, positional & patch embedding
|
||||
batch_size, _, height, width = hidden_states.shape
|
||||
|
||||
@@ -525,6 +539,10 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromO
|
||||
)
|
||||
output = torch.stack(output, dim=0)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn as nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
|
||||
@@ -404,7 +404,6 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -414,6 +413,21 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
|
||||
logger.warning(
|
||||
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p = self.config.patch_size
|
||||
|
||||
@@ -465,6 +479,10 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOri
|
||||
hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
|
||||
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -24,7 +24,7 @@ import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
@@ -829,7 +829,6 @@ class QwenImageTransformer2DModel(
self.gradient_checkpointing = False
self.zero_cond_t = zero_cond_t

@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -888,6 +887,20 @@ class QwenImageTransformer2DModel(
"The mask-based approach is more flexible and supports variable-length sequences.",
standard_warn=False,
)
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)

hidden_states = self.img_in(hidden_states)

@@ -968,6 +981,10 @@ class QwenImageTransformer2DModel(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (output,)

@@ -21,7 +21,7 @@ from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
@@ -570,7 +570,6 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro

self.gradient_checkpointing = False

@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -583,6 +582,21 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)

# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -681,6 +695,10 @@ class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (output,)

@@ -18,7 +18,7 @@ import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, FeedForward, JointTransformerBlock
from ..attention_processor import (
@@ -245,7 +245,6 @@ class SD3Transformer2DModel(
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)

@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -285,6 +284,20 @@ class SD3Transformer2DModel(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)

height, width = hidden_states.shape[-2:]

@@ -339,6 +352,10 @@ class SD3Transformer2DModel(
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (output,)

@@ -21,7 +21,7 @@ import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -630,7 +630,6 @@ class SkyReelsV2Transformer3DModel(

self.gradient_checkpointing = False

@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -642,6 +641,21 @@ class SkyReelsV2Transformer3DModel(
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)

batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -757,6 +771,10 @@ class SkyReelsV2Transformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (output,)

@@ -21,7 +21,7 @@ import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -622,7 +622,6 @@ class WanTransformer3DModel(

self.gradient_checkpointing = False

@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -632,6 +631,21 @@ class WanTransformer3DModel(
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)

batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -699,6 +713,10 @@ class WanTransformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (output,)

@@ -21,7 +21,7 @@ import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
@@ -1141,7 +1141,6 @@ class WanAnimateTransformer3DModel(

self.gradient_checkpointing = False

@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -1180,6 +1179,21 @@ class WanAnimateTransformer3DModel(
Whether to return the output as a dict or tuple.
"""

if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)

# Check that shapes match up
if pose_hidden_states is not None and pose_hidden_states.shape[2] + 1 != hidden_states.shape[2]:
raise ValueError(
@@ -1280,6 +1294,10 @@ class WanAnimateTransformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (output,)

@@ -20,7 +20,7 @@ import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin, FeedForward
from ..cache_utils import CacheMixin
from ..modeling_outputs import Transformer2DModelOutput
@@ -261,7 +261,6 @@ class WanVACETransformer3DModel(

self.gradient_checkpointing = False

@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -273,6 +272,21 @@ class WanVACETransformer3DModel(
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)

batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.config.patch_size
post_patch_num_frames = num_frames // p_t
@@ -365,6 +379,10 @@ class WanVACETransformer3DModel(
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (output,)

@@ -20,12 +20,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import (
BaseOutput,
apply_lora_scale,
deprecate,
logging,
)
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..activations import get_activation
from ..attention import AttentionMixin
from ..attention_processor import (
@@ -979,7 +974,6 @@ class UNet2DConditionModel(
encoder_hidden_states = (encoder_hidden_states, image_embeds)
return encoder_hidden_states

@apply_lora_scale("cross_attention_kwargs")
def forward(
self,
sample: torch.Tensor,
@@ -1118,6 +1112,18 @@ class UNet2DConditionModel(
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

# 3. down
# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
# to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
if cross_attention_kwargs is not None:
cross_attention_kwargs = cross_attention_kwargs.copy()
lora_scale = cross_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)

is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
is_adapter = down_intrablock_additional_residuals is not None
@@ -1233,6 +1239,10 @@ class UNet2DConditionModel(
sample = self.conv_act(sample)
sample = self.conv_out(sample)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (sample,)

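At the pipeline level, the popped `scale` is the LoRA weight users pass through `cross_attention_kwargs`. A brief sketch of that caller-side view (the checkpoint and LoRA repo ids are placeholders):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # placeholder checkpoint
).to("cuda")
pipe.load_lora_weights("some-user/some-lora")  # hypothetical LoRA repo

# "scale" is consumed by the UNet forward above and never reaches the attention processors
image = pipe("a photo of a corgi", cross_attention_kwargs={"scale": 0.5}).images[0]
```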
@@ -21,7 +21,7 @@ import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput, apply_lora_scale, deprecate, logging
from ...utils import BaseOutput, deprecate, logging
from ...utils.torch_utils import apply_freeu
from ..attention import AttentionMixin, BasicTransformerBlock
from ..attention_processor import (
@@ -1875,7 +1875,6 @@ class UNetMotionModel(ModelMixin, AttentionMixin, ConfigMixin, UNet2DConditionLo
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)

@apply_lora_scale("cross_attention_kwargs")
def forward(
self,
sample: torch.Tensor,

@@ -21,7 +21,6 @@ from torch.utils.checkpoint import checkpoint

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import apply_lora_scale
from ..attention import AttentionMixin, BasicTransformerBlock, SkipFFTransformerBlock
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -147,7 +146,6 @@ class UVit2DModel(ModelMixin, AttentionMixin, ConfigMixin, PeftAdapterMixin):

self.gradient_checkpointing = False

@apply_lora_scale("cross_attention_kwargs")
def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)

@@ -34,6 +34,7 @@ from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from .components_manager import ComponentsManager
from .modular_pipeline_utils import (
MODULAR_MODEL_CARD_TEMPLATE,
ComponentSpec,
ConfigSpec,
InputParam,
@@ -41,6 +42,7 @@ from .modular_pipeline_utils import (
OutputParam,
format_components,
format_configs,
generate_modular_model_card_content,
make_doc_string,
)

@@ -1753,9 +1755,19 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

# Generate modular pipeline card content
card_content = generate_modular_model_card_content(self.blocks)

# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
model_card = populate_model_card(model_card)
model_card = load_or_create_model_card(
repo_id,
token=token,
is_pipeline=True,
model_description=MODULAR_MODEL_CARD_TEMPLATE.format(**card_content),
is_modular=True,
)
model_card = populate_model_card(model_card, tags=card_content["tags"])

model_card.save(os.path.join(save_directory, "README.md"))

# YiYi TODO: maybe order the json file to make it more readable: configs first, then components

@@ -31,6 +31,30 @@ if is_torch_available():

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

# Template for modular pipeline model card description with placeholders
MODULAR_MODEL_CARD_TEMPLATE = """{model_description}

## Example Usage

[TODO]

## Pipeline Architecture

This modular pipeline is composed of the following blocks:

{blocks_description} {trigger_inputs_section}

## Model Components

{components_description} {configs_section}

## Input/Output Specification

### Inputs {inputs_description}

### Outputs {outputs_description}
"""


class InsertableDict(OrderedDict):
def insert(self, key, value, index):
@@ -916,3 +940,178 @@ def make_doc_string(
output += format_output_params(outputs, indent_level=2)

return output


def generate_modular_model_card_content(blocks) -> Dict[str, Any]:
"""
Generate model card content for a modular pipeline.

This function creates a comprehensive model card with descriptions of the pipeline's architecture, components,
configurations, inputs, and outputs.

Args:
blocks: The pipeline's blocks object containing all pipeline specifications

Returns:
Dict[str, Any]: A dictionary containing formatted content sections:
- pipeline_name: Name of the pipeline
- model_description: Overall description with pipeline type
- blocks_description: Detailed architecture of blocks
- components_description: List of required components
- configs_section: Configuration parameters section
- inputs_description: Input parameters specification
- outputs_description: Output parameters specification
- trigger_inputs_section: Conditional execution information
- tags: List of relevant tags for the model card
"""
blocks_class_name = blocks.__class__.__name__
pipeline_name = blocks_class_name.replace("Blocks", " Pipeline")
description = getattr(blocks, "description", "A modular diffusion pipeline.")

# generate blocks architecture description
blocks_desc_parts = []
sub_blocks = getattr(blocks, "sub_blocks", None) or {}
if sub_blocks:
for i, (name, block) in enumerate(sub_blocks.items()):
block_class = block.__class__.__name__
block_desc = block.description.split("\n")[0] if getattr(block, "description", "") else ""
blocks_desc_parts.append(f"{i + 1}. **{name}** (`{block_class}`)")
if block_desc:
blocks_desc_parts.append(f" - {block_desc}")

# add sub-blocks if any
if hasattr(block, "sub_blocks") and block.sub_blocks:
for sub_name, sub_block in block.sub_blocks.items():
sub_class = sub_block.__class__.__name__
sub_desc = sub_block.description.split("\n")[0] if getattr(sub_block, "description", "") else ""
blocks_desc_parts.append(f" - *{sub_name}*: `{sub_class}`")
if sub_desc:
blocks_desc_parts.append(f" - {sub_desc}")

blocks_description = "\n".join(blocks_desc_parts) if blocks_desc_parts else "No blocks defined."

components = getattr(blocks, "expected_components", [])
if components:
components_str = format_components(components, indent_level=0, add_empty_lines=False)
# remove the "Components:" header since template has its own
components_description = components_str.replace("Components:\n", "").strip()
if components_description:
# Convert to enumerated list
lines = [line.strip() for line in components_description.split("\n") if line.strip()]
enumerated_lines = [f"{i + 1}. {line}" for i, line in enumerate(lines)]
components_description = "\n".join(enumerated_lines)
else:
components_description = "No specific components required."
else:
components_description = "No specific components required. Components can be loaded dynamically."

configs = getattr(blocks, "expected_configs", [])
configs_section = ""
if configs:
configs_str = format_configs(configs, indent_level=0, add_empty_lines=False)
configs_description = configs_str.replace("Configs:\n", "").strip()
if configs_description:
configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}"

inputs = blocks.inputs
outputs = blocks.outputs

# format inputs as markdown list
inputs_parts = []
required_inputs = [inp for inp in inputs if inp.required]
optional_inputs = [inp for inp in inputs if not inp.required]

if required_inputs:
inputs_parts.append("**Required:**\n")
for inp in required_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}")

if optional_inputs:
if required_inputs:
inputs_parts.append("")
inputs_parts.append("**Optional:**\n")
for inp in optional_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
default_str = f", default: `{inp.default}`" if inp.default is not None else ""
inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}")

inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined."

# format outputs as markdown list
outputs_parts = []
for out in outputs:
if hasattr(out.type_hint, "__name__"):
type_str = out.type_hint.__name__
elif out.type_hint is not None:
type_str = str(out.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = out.description or "No description provided"
outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}")

outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs."

trigger_inputs_section = ""
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
if trigger_inputs_list:
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
trigger_inputs_section = f"""
### Conditional Execution

This pipeline contains blocks that are selected at runtime based on inputs:
- **Trigger Inputs**: {trigger_inputs_str}
"""

# generate tags based on pipeline characteristics
tags = ["modular-diffusers", "diffusers"]

if hasattr(blocks, "model_name") and blocks.model_name:
tags.append(blocks.model_name)

if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
triggers = blocks.trigger_inputs
if any(t in triggers for t in ["mask", "mask_image"]):
tags.append("inpainting")
if any(t in triggers for t in ["image", "image_latents"]):
tags.append("image-to-image")
if any(t in triggers for t in ["control_image", "controlnet_cond"]):
tags.append("controlnet")
if not any(t in triggers for t in ["image", "mask", "image_latents", "mask_image"]):
tags.append("text-to-image")
else:
tags.append("text-to-image")

block_count = len(blocks.sub_blocks)
model_description = f"""This is a modular diffusion pipeline built with 🧨 Diffusers' modular pipeline framework.

**Pipeline Type**: {blocks_class_name}

**Description**: {description}

This pipeline uses a {block_count}-block architecture that can be customized and extended."""

return {
"pipeline_name": pipeline_name,
"model_description": model_description,
"blocks_description": blocks_description,
"components_description": components_description,
"configs_section": configs_section,
"inputs_description": inputs_description,
"outputs_description": outputs_description,
"trigger_inputs_section": trigger_inputs_section,
"tags": tags,
}

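A rough, self-contained sketch of how the new helpers fit together. The `SimpleBlocks` stand-in below is an illustrative duck-typed object exposing only the attributes the generator reads (a real pipeline supplies these through its `blocks` attribute), and the card assembly mirrors the `is_modular` branch wired into `ModularPipeline` above:

```python
from huggingface_hub import ModelCard, ModelCardData

from diffusers.modular_pipelines.modular_pipeline_utils import (
    MODULAR_MODEL_CARD_TEMPLATE,
    generate_modular_model_card_content,
)


class SimpleBlocks:
    """Minimal stand-in exposing only the attributes the generator reads."""

    description = "Toy pipeline used for illustration."
    sub_blocks = {}
    expected_components = []
    expected_configs = []
    inputs = []
    outputs = []
    trigger_inputs = None
    model_name = None


card_content = generate_modular_model_card_content(SimpleBlocks())
print(sorted(card_content))  # pipeline_name, model_description, tags, ...

# Same assembly as the is_modular path, with a simplified ModelCardData
card = ModelCard(MODULAR_MODEL_CARD_TEMPLATE.format(**card_content))
card.data = ModelCardData(library_name="diffusers", tags=card_content["tags"])
card.save("README.md")
```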
@@ -131,7 +131,6 @@ from .loading_utils import get_module_from_name, get_submodule_by_name, load_ima
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
apply_lora_scale,
check_peft_version,
delete_adapter_layers,
get_adapter_name,

@@ -107,6 +107,7 @@ def load_or_create_model_card(
license: Optional[str] = None,
widget: Optional[List[dict]] = None,
inference: Optional[bool] = None,
is_modular: bool = False,
) -> ModelCard:
"""
Loads or creates a model card.
@@ -131,6 +132,8 @@ def load_or_create_model_card(
widget (`List[dict]`, *optional*): Widget to accompany a gallery template.
inference: (`bool`, optional): Whether to turn on inference widget. Helpful when using
`load_or_create_model_card` from a training script.
is_modular: (`bool`, optional): Boolean flag to denote if the model card is for a modular pipeline.
When True, uses model_description as-is without additional template formatting.
"""
if not is_jinja_available():
raise ValueError(
@@ -159,10 +162,14 @@ def load_or_create_model_card(
)
else:
card_data = ModelCardData()
component = "pipeline" if is_pipeline else "model"
if model_description is None:
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
model_card = ModelCard.from_template(card_data, model_description=model_description)
if is_modular and model_description is not None:
model_card = ModelCard(model_description)
model_card.data = card_data
else:
component = "pipeline" if is_pipeline else "model"
if model_description is None:
model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
model_card = ModelCard.from_template(card_data, model_description=model_description)

return model_card

@@ -16,7 +16,6 @@ PEFT utilities: Utilities related to peft library
"""

import collections
import functools
import importlib
from typing import Optional

@@ -276,59 +275,6 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
module.set_scale(adapter_name, get_module_weight(weight, module_name))


def apply_lora_scale(kwargs_name: str = "joint_attention_kwargs"):
"""
Decorator to automatically handle LoRA layer scaling/unscaling in forward methods.

This decorator extracts the `lora_scale` from the specified kwargs parameter, applies scaling before the forward
pass, and ensures unscaling happens after, even if an exception occurs.

Args:
kwargs_name (`str`, defaults to `"joint_attention_kwargs"`):
The name of the keyword argument that contains the LoRA scale. Common values include
"joint_attention_kwargs", "attention_kwargs", "cross_attention_kwargs", etc.
"""

def decorator(forward_fn):
@functools.wraps(forward_fn)
def wrapper(self, *args, **kwargs):
from . import USE_PEFT_BACKEND

lora_scale = 1.0
attention_kwargs = kwargs.get(kwargs_name)

if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
kwargs[kwargs_name] = attention_kwargs
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
if (
not USE_PEFT_BACKEND
and attention_kwargs is not None
and attention_kwargs.get("scale", None) is not None
):
logger.warning(
f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
)

# Apply LoRA scaling if using PEFT backend
if USE_PEFT_BACKEND:
scale_lora_layers(self, lora_scale)

try:
# Execute the forward pass
result = forward_fn(self, *args, **kwargs)
return result
finally:
# Always unscale, even if forward pass raises an exception
if USE_PEFT_BACKEND:
unscale_lora_layers(self, lora_scale)

return wrapper

return decorator
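For context, the removed decorator was applied to a model's `forward` roughly as in the sketch below. The toy module is an illustrative assumption, and the `apply_lora_scale` import only exists on the pre-change code (this very hunk deletes it):

```python
import torch

from diffusers.utils import apply_lora_scale  # removed by this commit; pre-change code only


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    @apply_lora_scale("cross_attention_kwargs")
    def forward(self, sample, cross_attention_kwargs=None):
        # by the time we get here, the wrapper has popped "scale" and called scale_lora_layers(self, scale)
        return self.proj(sample)


out = ToyModel()(torch.randn(1, 4), cross_attention_kwargs={"scale": 0.5})
```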


def check_peft_version(min_version: str) -> None:
r"""
Checks if the version of PEFT is compatible.

@@ -8,6 +8,13 @@ import torch
import diffusers
from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
InputParam,
OutputParam,
generate_modular_model_card_content,
)
from diffusers.utils import logging

from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device
@@ -335,3 +342,239 @@ class ModularGuiderTesterMixin:
assert out_cfg.shape == out_no_cfg.shape
max_diff = torch.abs(out_cfg - out_no_cfg).max()
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"


class TestModularModelCardContent:
def create_mock_block(self, name="TestBlock", description="Test block description"):
class MockBlock:
def __init__(self, name, description):
self.__class__.__name__ = name
self.description = description
self.sub_blocks = {}

return MockBlock(name, description)

def create_mock_blocks(
self,
class_name="TestBlocks",
description="Test pipeline description",
num_blocks=2,
components=None,
configs=None,
inputs=None,
outputs=None,
trigger_inputs=None,
model_name=None,
):
class MockBlocks:
def __init__(self):
self.__class__.__name__ = class_name
self.description = description
self.sub_blocks = {}
self.expected_components = components or []
self.expected_configs = configs or []
self.inputs = inputs or []
self.outputs = outputs or []
self.trigger_inputs = trigger_inputs
self.model_name = model_name

blocks = MockBlocks()

# Add mock sub-blocks
for i in range(num_blocks):
block_name = f"block_{i}"
blocks.sub_blocks[block_name] = self.create_mock_block(f"Block{i}", f"Description for block {i}")

return blocks

def test_basic_model_card_content_structure(self):
"""Test that all expected keys are present in the output."""
blocks = self.create_mock_blocks()
content = generate_modular_model_card_content(blocks)

expected_keys = [
"pipeline_name",
"model_description",
"blocks_description",
"components_description",
"configs_section",
"inputs_description",
"outputs_description",
"trigger_inputs_section",
"tags",
]

for key in expected_keys:
assert key in content, f"Expected key '{key}' not found in model card content"

assert isinstance(content["tags"], list), "Tags should be a list"

def test_pipeline_name_generation(self):
"""Test that pipeline name is correctly generated from blocks class name."""
blocks = self.create_mock_blocks(class_name="StableDiffusionBlocks")
content = generate_modular_model_card_content(blocks)

assert content["pipeline_name"] == "StableDiffusion Pipeline"

def test_tags_generation_text_to_image(self):
"""Test that text-to-image tags are correctly generated."""
blocks = self.create_mock_blocks(trigger_inputs=None)
content = generate_modular_model_card_content(blocks)

assert "modular-diffusers" in content["tags"]
assert "diffusers" in content["tags"]
assert "text-to-image" in content["tags"]

def test_tags_generation_with_trigger_inputs(self):
"""Test that tags are correctly generated based on trigger inputs."""
# Test inpainting
blocks = self.create_mock_blocks(trigger_inputs=["mask", "prompt"])
content = generate_modular_model_card_content(blocks)
assert "inpainting" in content["tags"]

# Test image-to-image
blocks = self.create_mock_blocks(trigger_inputs=["image", "prompt"])
content = generate_modular_model_card_content(blocks)
assert "image-to-image" in content["tags"]

# Test controlnet
blocks = self.create_mock_blocks(trigger_inputs=["control_image", "prompt"])
content = generate_modular_model_card_content(blocks)
assert "controlnet" in content["tags"]

def test_tags_with_model_name(self):
"""Test that model name is included in tags when present."""
blocks = self.create_mock_blocks(model_name="stable-diffusion-xl")
content = generate_modular_model_card_content(blocks)

assert "stable-diffusion-xl" in content["tags"]

def test_components_description_formatting(self):
"""Test that components are correctly formatted."""
components = [
ComponentSpec(name="vae", description="VAE component"),
ComponentSpec(name="text_encoder", description="Text encoder component"),
]
blocks = self.create_mock_blocks(components=components)
content = generate_modular_model_card_content(blocks)

assert "vae" in content["components_description"]
assert "text_encoder" in content["components_description"]
# Should be enumerated
assert "1." in content["components_description"]

def test_components_description_empty(self):
"""Test handling of pipelines without components."""
blocks = self.create_mock_blocks(components=None)
content = generate_modular_model_card_content(blocks)

assert "No specific components required" in content["components_description"]

def test_configs_section_with_configs(self):
"""Test that configs section is generated when configs are present."""
configs = [
ConfigSpec(name="num_train_timesteps", default=1000, description="Number of training timesteps"),
]
blocks = self.create_mock_blocks(configs=configs)
content = generate_modular_model_card_content(blocks)

assert "## Configuration Parameters" in content["configs_section"]

def test_configs_section_empty(self):
"""Test that configs section is empty when no configs are present."""
blocks = self.create_mock_blocks(configs=None)
content = generate_modular_model_card_content(blocks)

assert content["configs_section"] == ""

def test_inputs_description_required_and_optional(self):
"""Test that required and optional inputs are correctly formatted."""
inputs = [
InputParam(name="prompt", type_hint=str, required=True, description="The input prompt"),
InputParam(name="num_steps", type_hint=int, required=False, default=50, description="Number of steps"),
]
blocks = self.create_mock_blocks(inputs=inputs)
content = generate_modular_model_card_content(blocks)

assert "**Required:**" in content["inputs_description"]
assert "**Optional:**" in content["inputs_description"]
assert "prompt" in content["inputs_description"]
assert "num_steps" in content["inputs_description"]
assert "default: `50`" in content["inputs_description"]

def test_inputs_description_empty(self):
"""Test handling of pipelines without specific inputs."""
blocks = self.create_mock_blocks(inputs=[])
content = generate_modular_model_card_content(blocks)

assert "No specific inputs defined" in content["inputs_description"]

def test_outputs_description_formatting(self):
"""Test that outputs are correctly formatted."""
outputs = [
OutputParam(name="images", type_hint=torch.Tensor, description="Generated images"),
]
blocks = self.create_mock_blocks(outputs=outputs)
content = generate_modular_model_card_content(blocks)

assert "images" in content["outputs_description"]
assert "Generated images" in content["outputs_description"]

def test_outputs_description_empty(self):
"""Test handling of pipelines without specific outputs."""
blocks = self.create_mock_blocks(outputs=[])
content = generate_modular_model_card_content(blocks)

assert "Standard pipeline outputs" in content["outputs_description"]

def test_trigger_inputs_section_with_triggers(self):
"""Test that trigger inputs section is generated when present."""
blocks = self.create_mock_blocks(trigger_inputs=["mask", "image"])
content = generate_modular_model_card_content(blocks)

assert "### Conditional Execution" in content["trigger_inputs_section"]
assert "`mask`" in content["trigger_inputs_section"]
assert "`image`" in content["trigger_inputs_section"]

def test_trigger_inputs_section_empty(self):
"""Test that trigger inputs section is empty when not present."""
blocks = self.create_mock_blocks(trigger_inputs=None)
content = generate_modular_model_card_content(blocks)

assert content["trigger_inputs_section"] == ""

def test_blocks_description_with_sub_blocks(self):
"""Test that blocks with sub-blocks are correctly described."""

class MockBlockWithSubBlocks:
def __init__(self):
self.__class__.__name__ = "ParentBlock"
self.description = "Parent block"
self.sub_blocks = {
"child1": self.create_child_block("ChildBlock1", "Child 1 description"),
"child2": self.create_child_block("ChildBlock2", "Child 2 description"),
}

def create_child_block(self, name, desc):
class ChildBlock:
def __init__(self):
self.__class__.__name__ = name
self.description = desc

return ChildBlock()

blocks = self.create_mock_blocks()
blocks.sub_blocks["parent"] = MockBlockWithSubBlocks()

content = generate_modular_model_card_content(blocks)

assert "parent" in content["blocks_description"]
assert "child1" in content["blocks_description"]
assert "child2" in content["blocks_description"]

def test_model_description_includes_block_count(self):
"""Test that model description includes the number of blocks."""
blocks = self.create_mock_blocks(num_blocks=5)
content = generate_modular_model_card_content(blocks)

assert "5-block architecture" in content["model_description"]

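To exercise only the new model-card tests, something along the lines of `pytest tests/modular_pipelines -k "TestModularModelCardContent" -q` should work; the test-file location is an assumption based on the relative imports above, so adjust the path to wherever this test module lives in your checkout.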