Compare commits

...

55 Commits

Author SHA1 Message Date
sayakpaul
a8c2dd12d7 empty commit 2025-04-08 10:43:56 +05:30
github-actions[bot]
65a3bf5f81 Apply style fixes 2025-04-08 04:46:02 +00:00
Sayak Paul
ed33194780 Merge branch 'main' into auraflow-lora 2025-04-08 10:07:28 +05:30
Hameer Abbasi
62411093f8 Merge remote-tracking branch 'upstream/main' into auraflow-lora 2025-03-26 03:40:12 +01:00
Hameer Abbasi
dee9074b9e Fix AuraFlow LoRA tests by applying to the right denoiser layers.
Co-authored-by: AstraliteHeart <81396681+AstraliteHeart@users.noreply.github.com>
2025-03-26 03:39:43 +01:00
Hameer Abbasi
5091757022 Unrelated TE LoRA tests fix. 2025-03-06 07:15:19 +01:00
Hameer Abbasi
b19942f106 MPS fix for LoRA tests. 2025-03-06 07:09:01 +01:00
Hameer Abbasi
8aa2d6909d Fix TE target modules. 2025-03-06 07:02:44 +01:00
Hameer Abbasi
c2daa8a9bd Propagate attention. 2025-03-05 10:14:39 +01:00
Hameer Abbasi
75ba7dacd7 Remove repeated docs. 2025-03-05 10:13:57 +01:00
Hameer Abbasi
636f01cdf3 Remove unnecessary code. 2025-02-27 15:45:21 +01:00
Hameer Abbasi
c11b14d729 Remove duplicated code. 2025-02-27 15:28:07 +01:00
Hameer Abbasi
3b9e65543a Address review comments. 2025-02-26 11:47:16 +01:00
Hameer Abbasi
ce1939b374 Skip for MPS. 2025-02-26 10:59:07 +01:00
Hameer Abbasi
cdd184d44f Mirror Lumina2. 2025-02-26 10:49:17 +01:00
Hameer Abbasi
c602749050 Merge branch 'main' into auraflow-lora 2025-02-26 06:52:47 +01:00
Hameer Abbasi
12fbd11853 Remove unneeded stuff. 2025-01-19 15:48:20 +01:00
Hameer Abbasi
83e0825a74 Sync with upstream changes. 2025-01-19 15:09:12 +01:00
Hameer Abbasi
0fa5cd5d14 make fix-copies 2025-01-19 14:57:44 +01:00
Hameer Abbasi
cd691d3762 Undo support for TE in AuraFlow LoRA. 2025-01-19 14:56:00 +01:00
Hameer Abbasi
5620384df3 Undo adding LoRA support for non-CLIP TEs. 2025-01-19 14:55:14 +01:00
Hameer Abbasi
7e63330ef1 Restore LoRA tests. 2025-01-19 14:51:16 +01:00
Hameer Abbasi
df28362b78 Merge remote-tracking branch 'upstream/main' into auraflow-lora 2025-01-19 14:48:52 +01:00
Hameer Abbasi
3095644ebe Merge branch 'main' into auraflow-lora 2025-01-13 06:43:33 +01:00
hlky
077a452ec3 Merge branch 'main' into auraflow-lora 2025-01-10 07:22:56 +00:00
Hameer Abbasi
1ec07a1fe4 Merge remote-tracking branch 'upstream/main' into auraflow-lora 2025-01-10 08:20:28 +01:00
Hameer Abbasi
2b359094b6 Support LoRA for non-CLIP TEs. 2025-01-10 08:02:59 +01:00
Hameer Abbasi
e06d8eb94f Skip TE-only tests for AuraFlow. 2025-01-08 15:38:37 +01:00
hlky
0ea9ecd1ff Merge branch 'main' into auraflow-lora 2025-01-08 13:36:36 +00:00
Hameer Abbasi
532013f990 Remove useless property. 2025-01-07 05:56:36 +01:00
Hameer Abbasi
2b934b458a Make test better. 2025-01-07 05:16:41 +01:00
Hameer Abbasi
2d02c2c8d2 Add lora_scale property for te LoRAs. 2025-01-07 04:20:30 +01:00
Hameer Abbasi
6e899a385a Merge branch 'main' into auraflow-lora 2025-01-07 04:10:26 +01:00
Hameer Abbasi
d6028cdc53 joint_attention_kwargs->attention_kwargs 2025-01-07 04:09:35 +01:00
Hameer Abbasi
28a4918a47 Get more tests passing by copying from Flux. Address review comments. 2025-01-07 04:02:01 +01:00
Sayak Paul
6da81f8960 Merge branch 'main' into auraflow-lora 2025-01-06 22:23:06 +05:30
Hameer Abbasi
5700e52757 Merge branch 'main' into auraflow-lora 2025-01-03 11:35:00 +01:00
Hameer Abbasi
9454e845e6 Rebasing derp. 2024-12-19 10:25:49 +01:00
Hameer Abbasi
1c7909565a Address review comments. 2024-12-19 10:23:21 +01:00
Hameer Abbasi
bc2a4663fd Attept 3 to fix tests. 2024-12-19 09:58:19 +01:00
Hameer Abbasi
6b762b800c Attept 2 to fix tests. 2024-12-19 09:58:19 +01:00
Hameer Abbasi
2b36416192 Attept 1 to fix tests. 2024-12-19 09:58:19 +01:00
Hameer Abbasi
894eac0aa8 ruff check --fix 2024-12-19 09:58:19 +01:00
Hameer Abbasi
a73df6b949 make fix-copies 2024-12-19 09:58:19 +01:00
hlky
a242d7ae69 make style 2024-12-19 09:58:19 +01:00
Hameer Abbasi
875a3e0f90 Quality fixes. 2024-12-19 09:58:19 +01:00
hlky
1b7f99fb84 fix 2024-12-19 09:58:18 +01:00
hlky
c07d1f5185 fix 2024-12-19 09:58:18 +01:00
Hameer Abbasi
4e4f780e54 Rebasing derp. 2024-12-19 09:58:18 +01:00
Warlord-K
0eee03eefe Change attention_kwargs->joint_attention_kwargs 2024-12-19 09:58:18 +01:00
Warlord-K
71f8bace8b Add Suggested changes 2024-12-19 09:58:14 +01:00
Warlord-K
98b19f66aa Add AuraFlowLoraLoaderMixin to documentation 2024-12-19 09:57:49 +01:00
Warlord-K
4208d09601 Add Tests 2024-12-19 09:57:24 +01:00
Warlord-K
658d058686 Add comments, remove qkv fusion 2024-12-19 09:57:24 +01:00
Warlord-K
e50a108208 Add AuraFlowLoraLoaderMixin 2024-12-19 09:57:24 +01:00
9 changed files with 564 additions and 25 deletions

View File

@@ -20,6 +20,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`FluxLoraLoaderMixin`] provides similar functions for [Flux](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux).
- [`CogVideoXLoraLoaderMixin`] provides similar functions for [CogVideoX](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogvideox).
- [`Mochi1LoraLoaderMixin`] provides similar functions for [Mochi](https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi).
- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow).
- [`LTXVideoLoraLoaderMixin`] provides similar functions for [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).
- [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana).
- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
@@ -56,6 +57,9 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
## Mochi1LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
## AuraFlowLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin
## LTXVideoLoraLoaderMixin

View File

@@ -65,6 +65,7 @@ if is_torch_available():
"AmusedLoraLoaderMixin",
"StableDiffusionLoraLoaderMixin",
"SD3LoraLoaderMixin",
"AuraFlowLoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LTXVideoLoraLoaderMixin",
"LoraLoaderMixin",
@@ -103,6 +104,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .lora_pipeline import (
AmusedLoraLoaderMixin,
AuraFlowLoraLoaderMixin,
CogVideoXLoraLoaderMixin,
CogView4LoraLoaderMixin,
FluxLoraLoaderMixin,

View File

@@ -1328,6 +1328,311 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs)
class AuraFlowLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`AuraFlowTransformer2DModel`] Specific to [`AuraFlowPipeline`].
"""
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
Return state dict for lora weights and the network alphas.
<Tip warning={true}>
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
This function is experimental and might change in the future.
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
[`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
dict is loaded into `self.transformer`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
transformer (`AuraFlowTransformer2DModel`):
The Transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
Example:
```py
from diffusers import DiffusionPipeline
import torch
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
<Tip warning={true}>
This is an experimental API.
</Tip>
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components)
class FluxLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`FluxTransformer2DModel`],

View File

@@ -52,6 +52,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
"SanaTransformer2DModel": lambda model_cls, weights: weights,
"AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
"WanTransformer3DModel": lambda model_cls, weights: weights,
"CogView4Transformer2DModel": lambda model_cls, weights: weights,

View File

@@ -38,7 +38,7 @@ if is_transformers_available():
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def text_encoder_attn_modules(text_encoder):
def text_encoder_attn_modules(text_encoder: nn.Module):
attn_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
@@ -52,7 +52,7 @@ def text_encoder_attn_modules(text_encoder):
return attn_modules
def text_encoder_mlp_modules(text_encoder):
def text_encoder_mlp_modules(text_encoder: nn.Module):
mlp_modules = []
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):

View File

@@ -13,15 +13,15 @@
# limitations under the License.
from typing import Dict, Union
from typing import Any, Dict, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_processor import (
Attention,
@@ -160,14 +160,20 @@ class AuraFlowSingleTransformerBlock(nn.Module):
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
self.ff = AuraFlowFeedForward(dim, dim * 4)
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
):
residual = hidden_states
attention_kwargs = attention_kwargs or {}
# Norm + Projection.
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
# Attention.
attn_output = self.attn(hidden_states=norm_hidden_states)
attn_output = self.attn(hidden_states=norm_hidden_states, **attention_kwargs)
# Process attention outputs for the `hidden_states`.
hidden_states = self.norm2(residual + gate_msa.unsqueeze(1) * attn_output)
@@ -223,10 +229,15 @@ class AuraFlowJointTransformerBlock(nn.Module):
self.ff_context = AuraFlowFeedForward(dim, dim * 4)
def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
):
residual = hidden_states
residual_context = encoder_hidden_states
attention_kwargs = attention_kwargs or {}
# Norm + Projection.
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
@@ -236,7 +247,9 @@ class AuraFlowJointTransformerBlock(nn.Module):
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
**attention_kwargs,
)
# Process attention outputs for the `hidden_states`.
@@ -254,7 +267,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
return encoder_hidden_states, hidden_states
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
r"""
A 2D Transformer model as introduced in AuraFlow (https://blog.fal.ai/auraflow/).
@@ -449,8 +462,24 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
timestep: torch.LongTensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
height, width = hidden_states.shape[-2:]
# Apply patch embedding, timestep embedding, and project the caption embeddings.
@@ -474,7 +503,10 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
attention_kwargs=attention_kwargs,
)
# Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
@@ -491,7 +523,9 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
)
else:
combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb)
combined_hidden_states = block(
hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs
)
hidden_states = combined_hidden_states[:, encoder_seq_len:]
@@ -512,6 +546,10 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
shape=(hidden_states.shape[0], out_channels, height * patch_size, width * patch_size)
)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -12,17 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from transformers import T5Tokenizer, UMT5EncoderModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...loaders import AuraFlowLoraLoaderMixin
from ...models import AuraFlowTransformer2DModel, AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -112,7 +120,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class AuraFlowPipeline(DiffusionPipeline):
class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
r"""
Args:
tokenizer (`T5TokenizerFast`):
@@ -233,6 +241,7 @@ class AuraFlowPipeline(DiffusionPipeline):
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 256,
lora_scale: Optional[float] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
@@ -259,10 +268,20 @@ class AuraFlowPipeline(DiffusionPipeline):
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, AuraFlowLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
if device is None:
device = self._execution_device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
@@ -346,6 +365,11 @@ class AuraFlowPipeline(DiffusionPipeline):
negative_prompt_embeds = None
negative_prompt_attention_mask = None
if self.text_encoder is not None:
if isinstance(self, AuraFlowLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
@@ -403,6 +427,10 @@ class AuraFlowPipeline(DiffusionPipeline):
def guidance_scale(self):
return self._guidance_scale
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@@ -428,6 +456,7 @@ class AuraFlowPipeline(DiffusionPipeline):
max_sequence_length: int = 256,
output_type: Optional[str] = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
@@ -486,6 +515,10 @@ class AuraFlowPipeline(DiffusionPipeline):
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -520,6 +553,7 @@ class AuraFlowPipeline(DiffusionPipeline):
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
# 2. Determine batch size.
if prompt is not None and isinstance(prompt, str):
@@ -530,6 +564,7 @@ class AuraFlowPipeline(DiffusionPipeline):
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -553,6 +588,7 @@ class AuraFlowPipeline(DiffusionPipeline):
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
@@ -594,6 +630,7 @@ class AuraFlowPipeline(DiffusionPipeline):
encoder_hidden_states=prompt_embeds,
timestep=timestep,
return_dict=False,
attention_kwargs=self.attention_kwargs,
)[0]
# perform guidance

View File

@@ -0,0 +1,136 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers import (
AuraFlowPipeline,
AuraFlowTransformer2DModel,
FlowMatchEulerDiscreteScheduler,
)
from diffusers.utils.testing_utils import (
floats_tensor,
is_peft_available,
require_peft_backend,
)
if is_peft_available():
pass
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = AuraFlowPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
"sample_size": 64,
"patch_size": 1,
"in_channels": 4,
"num_mmdit_layers": 1,
"num_single_dit_layers": 1,
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"caption_projection_dim": 32,
"pos_embed_max_size": 64,
}
transformer_cls = AuraFlowTransformer2DModel
vae_kwargs = {
"sample_size": 32,
"in_channels": 3,
"out_channels": 3,
"block_out_channels": (4,),
"layers_per_block": 1,
"latent_channels": 4,
"norm_num_groups": 1,
"use_quant_conv": False,
"use_post_quant_conv": False,
"shift_factor": 0.0609,
"scaling_factor": 1.5035,
}
tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
text_encoder_cls, text_encoder_id = UMT5EncoderModel, "hf-internal-testing/tiny-random-umt5"
text_encoder_target_modules = ["q", "k", "v", "o"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0", "linear_1"]
@property
def output_shape(self):
return (1, 8, 8, 3)
def get_dummy_inputs(self, with_generator=True):
batch_size = 1
sequence_length = 10
num_channels = 4
sizes = (32, 32)
generator = torch.manual_seed(0)
noise = floats_tensor((batch_size, num_channels) + sizes)
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger",
"num_inference_steps": 4,
"guidance_scale": 0.0,
"height": 8,
"width": 8,
"output_type": "np",
}
if with_generator:
pipeline_inputs.update({"generator": generator})
return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in AuraFlow.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in AuraFlow.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in AuraFlow.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -104,6 +104,7 @@ class PeftLoraLoaderMixinTests:
vae_kwargs = None
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
def get_dummy_components(self, scheduler_cls=None, use_dora=False):
if self.unet_kwargs and self.transformer_kwargs:
@@ -157,7 +158,7 @@ class PeftLoraLoaderMixinTests:
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=rank,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
target_modules=self.denoiser_target_modules,
init_lora_weights=False,
use_dora=use_dora,
)
@@ -641,9 +642,9 @@ class PeftLoraLoaderMixinTests:
# Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
text_lora_config = LoraConfig(
r=4,
rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)},
lora_alpha=4,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
target_modules=self.text_encoder_target_modules,
init_lora_weights=False,
use_dora=False,
)
@@ -2090,7 +2091,7 @@ class PeftLoraLoaderMixinTests:
bias_values = {}
denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
for name, module in denoiser.named_modules():
if any(k in name for k in ["to_q", "to_k", "to_v", "to_out.0"]):
if any(k in name for k in self.denoiser_target_modules):
if module.bias is not None:
bias_values[name] = module.bias.data.clone()
@@ -2187,6 +2188,17 @@ class PeftLoraLoaderMixinTests:
self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
@property
def supports_text_encoder_lora(self):
return (
len(
{"text_encoder", "text_encoder_2", "text_encoder_3"}.intersection(
self.pipeline_class._lora_loadable_modules
)
)
!= 0
)
def test_layerwise_casting_inference_denoiser(self):
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
@@ -2239,11 +2251,15 @@ class PeftLoraLoaderMixinTests:
pipe_fp32 = initialize_pipeline(storage_dtype=None)
pipe_fp32(**inputs, generator=torch.manual_seed(0))[0]
pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0]
# MPS doesn't support float8 yet.
if torch_device not in {"mps"}:
pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0]
pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
pipe_float8_e4m3_bf16 = initialize_pipeline(
storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
@require_peft_version_greater("0.14.0")
def test_layerwise_casting_peft_input_autocast_denoiser(self):
@@ -2268,7 +2284,7 @@ class PeftLoraLoaderMixinTests:
apply_layerwise_casting,
)
storage_dtype = torch.float8_e4m3fn
storage_dtype = torch.float8_e4m3fn if not torch_device == "mps" else torch.bfloat16
compute_dtype = torch.float32
def check_module(denoiser):