Compare commits

..

31 Commits

Author SHA1 Message Date
Sayak Paul db250958c5 Merge branch 'main' into migrate-lora-pytest 2025-12-10 12:36:41 +08:00
sayakpaul f956ba0db1 resolve conflicts. 2025-12-04 20:07:15 +08:00
sayakpaul f3593a8aa9 up 2025-12-03 17:42:36 +08:00
Sayak Paul 1b6cdea043 Merge branch 'main' into migrate-lora-pytest 2025-12-03 17:39:35 +08:00
Sayak Paul 3fb66f23ac Merge branch 'main' into migrate-lora-pytest 2025-11-20 10:13:01 +05:30
sayakpaul 9c3bed1783 up 2025-11-20 10:12:31 +05:30
Sayak Paul 11b80d09b0 Merge branch 'main' into migrate-lora-pytest 2025-11-10 13:27:10 +05:30
Sayak Paul 9201505554 Merge branch 'main' into migrate-lora-pytest 2025-11-06 10:39:44 +05:30
sayakpaul eece7120dd up 2025-11-06 10:31:37 +05:30
Sayak Paul 2e42205c3a Merge branch 'main' into migrate-lora-pytest 2025-11-06 10:24:51 +05:30
Sayak Paul 757bbf7b05 Merge branch 'main' into migrate-lora-pytest 2025-10-24 22:24:15 +05:30
Sayak Paul 4561c065aa Merge branch 'main' into migrate-lora-pytest 2025-10-17 19:29:40 +05:30
Sayak Paul 4ae5772fef Merge branch 'main' into migrate-lora-pytest 2025-10-17 07:55:31 +05:30
sayakpaul 0d3da485a0 up 2025-10-03 21:00:05 +05:30
sayakpaul 4f5e9a665e up 2025-10-03 20:49:50 +05:30
Sayak Paul 23e5559c54 Merge branch 'main' into migrate-lora-pytest 2025-10-03 20:44:52 +05:30
sayakpaul f8f27891c6 up 2025-10-03 20:14:45 +05:30
sayakpaul 128535cfcd up 2025-10-03 20:03:50 +05:30
sayakpaul bdc9537999 more fixtures. 2025-10-03 20:01:26 +05:30
sayakpaul dae161ed26 up 2025-10-03 17:39:55 +05:30
sayakpaul c4bcf72084 up 2025-10-03 16:56:31 +05:30
sayakpaul 1737b710a2 up 2025-10-03 16:45:04 +05:30
sayakpaul 565d674cc4 change flux lora integration tests to use pytest 2025-10-03 16:30:58 +05:30
sayakpaul 610842af1a up 2025-10-03 16:14:36 +05:30
sayakpaul cba82591e8 up 2025-10-03 15:56:37 +05:30
sayakpaul 949cc1c326 up 2025-10-03 14:54:23 +05:30
sayakpaul ec866f5de8 tempfile is now a fixture. 2025-10-03 14:25:54 +05:30
sayakpaul 7b4bcce602 up 2025-10-03 14:10:31 +05:30
sayakpaul d61bb38fb4 up 2025-10-03 13:14:05 +05:30
sayakpaul 9e92f6bb63 up 2025-10-03 12:53:37 +05:30
sayakpaul 6c6cade1a7 migrate lora pipeline tests to pytest 2025-10-03 12:52:56 +05:30
34 changed files with 1283 additions and 3729 deletions

View File

@@ -237,8 +237,6 @@ By selectively loading and unloading the models you need at a given stage and sh
Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends.
Most attention backends are compatible with context parallelism. Open an [issue](https://github.com/huggingface/diffusers/issues/new) if a backend is not compatible.
### Ring Attention
Key (K) and value (V) representations communicate between devices using [Ring Attention](https://huggingface.co/papers/2310.01889). This ensures each split sees every other token's K/V. Each GPU computes attention for its local K/V and passes it to the next GPU in the ring. No single GPU holds the full sequence, which reduces communication latency.
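For intuition only, here is a minimal single-process sketch of the ring schedule (not diffusers internals, toy shapes assumed): each rank keeps its own query shard and sees one K/V shard per hop, so no rank ever materializes the full K/V. Real implementations merge partial results online (flash-attention style) and overlap the ring communication with compute rather than concatenating scores as done here.
```py
import torch

def ring_attention_sim(q_shards, k_shards, v_shards):
    """Single-process simulation of the ring schedule (illustrative sketch only)."""
    world_size = len(q_shards)
    outputs = []
    for rank in range(world_size):
        q = q_shards[rank]                       # this rank's query shard, shape (n, d)
        scores, values = [], []
        for hop in range(world_size):
            peer = (rank + hop) % world_size     # K/V shard arriving at this hop of the ring
            scores.append(q @ k_shards[peer].T)  # attention scores against one shard at a time
            values.append(v_shards[peer])
        attn = torch.softmax(torch.cat(scores, dim=-1) / q.shape[-1] ** 0.5, dim=-1)
        outputs.append(attn @ torch.cat(values, dim=0))
    return torch.cat(outputs, dim=0)             # matches full attention over the whole sequence

# Example: a sequence of 8 tokens split across 2 "ranks"
q = [torch.randn(4, 16) for _ in range(2)]
k = [torch.randn(4, 16) for _ in range(2)]
v = [torch.randn(4, 16) for _ in range(2)]
out = ring_attention_sim(q, k, v)  # shape (8, 16)
```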
@@ -247,60 +245,40 @@ Pass a [`ContextParallelConfig`] to the `parallel_config` argument of the transf
```py
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig
from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig
def setup_distributed():
if not dist.is_initialized():
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
device = torch.device(f"cuda:{rank}")
try:
torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)
return device
def main():
device = setup_distributed()
world_size = dist.get_world_size()
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
)
pipeline.transformer.set_attention_backend("_native_cudnn")
cp_config = ContextParallelConfig(ring_degree=world_size)
pipeline.transformer.enable_parallelism(config=cp_config)
transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2))
pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda")
pipeline.transformer.set_attention_backend("flash")
prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
# Must specify generator so all ranks start with same latents (or pass your own)
generator = torch.Generator().manual_seed(42)
image = pipeline(
prompt,
guidance_scale=3.5,
num_inference_steps=50,
generator=generator,
).images[0]
image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0]
if rank == 0:
image.save("output.png")
if dist.get_rank() == 0:
image.save(f"output.png")
except Exception as e:
print(f"An error occurred: {e}")
torch.distributed.breakpoint()
raise
if dist.is_initialized():
dist.destroy_process_group()
if __name__ == "__main__":
main()
finally:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
```
The script above must be launched with a PyTorch-compatible distributed launcher such as [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html). Set `--nproc-per-node` to the number of available GPUs.
```shell
torchrun --nproc-per-node 2 above_script.py
```
### Ulysses Attention
[Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* communication (every device sends/receives data to every other device). Each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally on all tokens for its head, then performs another all-to-all to regroup results by tokens for the next layer.
@@ -310,26 +288,5 @@ The script above needs to be run with a distributed launcher, such as [torchrun]
Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
```py
# Depending on the number of GPUs available.
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
```
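For reference, a minimal sketch of the same Qwen-Image setup switched to Ulysses attention, assuming two GPUs and the `AutoModel`/`QwenImagePipeline` entry points used in the Ring Attention example above (launch with `torchrun` and initialize the process group exactly as shown there):
```py
import torch
from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig

# Identical to the Ring Attention example above, except attention heads are
# split across ranks with Ulysses instead of rotating K/V shards around a ring.
transformer = AutoModel.from_pretrained(
    "Qwen/Qwen-Image",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=ContextParallelConfig(ulysses_degree=2),
)
pipeline = QwenImagePipeline.from_pretrained(
    "Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda"
)
pipeline.transformer.set_attention_backend("flash")
```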
### parallel_config
Pass `parallel_config` during model initialization to enable context parallelism.
```py
CKPT_ID = "black-forest-labs/FLUX.1-dev"
cp_config = ContextParallelConfig(ring_degree=2)
transformer = AutoModel.from_pretrained(
CKPT_ID,
subfolder="transformer",
torch_dtype=torch.bfloat16,
parallel_config=cp_config
)
pipeline = DiffusionPipeline.from_pretrained(
CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
).to(device)
```

View File

@@ -404,8 +404,6 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["modular_pipelines"].extend(
[
"Flux2AutoBlocks",
"Flux2ModularPipeline",
"FluxAutoBlocks",
"FluxKontextAutoBlocks",
"FluxKontextModularPipeline",
@@ -1113,8 +1111,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_pipelines import (
Flux2AutoBlocks,
Flux2ModularPipeline,
FluxAutoBlocks,
FluxKontextAutoBlocks,
FluxKontextModularPipeline,

View File

@@ -1487,7 +1487,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
Load LoRA layers into [`FluxTransformer2DModel`],
[`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).
Specific to [`FluxPipeline`].
Specific to [`StableDiffusion3Pipeline`].
"""
_lora_loadable_modules = ["transformer", "text_encoder"]
@@ -1628,7 +1628,30 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
`self.text_encoder`.
All kwargs are forwarded to `self.lora_state_dict`.
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
loaded.
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
dict is loaded into `self.transformer`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
kwargs (`dict`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -3628,17 +3651,44 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
@classmethod
@validate_hf_hub_args
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
Return state dict for lora weights and the network alphas.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* of a pretrained model hosted on the Hub.
- A path to a *directory* containing the model weights.
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
use_safetensors (`bool`, *optional*):
Whether to use safetensors for loading.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
# Load the main state dict first which has the LoRA layers
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -3681,7 +3731,6 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -3690,13 +3739,26 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model.
hotswap (`bool`, *optional*):
Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
kwargs (`dict`, *optional*):
See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
@@ -3713,6 +3775,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
# Load LoRA into transformer
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
@@ -3724,7 +3787,6 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
def load_lora_into_transformer(
cls,
state_dict,
@@ -3736,9 +3798,23 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
Load the LoRA layers specified in `state_dict` into `transformer`.
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters.
transformer (`Kandinsky5Transformer3DModel`):
The transformer model to load the LoRA layers into.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model.
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights.
hotswap (`bool`, *optional*):
See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
@@ -3756,7 +3832,6 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
@@ -3765,10 +3840,24 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
transformer_lora_adapter_metadata=None,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
Save the LoRA parameters corresponding to the transformer and text encoders.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to.
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `transformer`.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process.
save_function (`Callable`):
The function to use to save the state dictionary.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer.
"""
lora_layers = {}
lora_metadata = {}
@@ -3778,7 +3867,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
raise ValueError("You must pass at least one of `transformer_lora_layers`")
cls._save_lora_weights(
save_directory=save_directory,
@@ -3790,7 +3879,6 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
@@ -3800,7 +3888,25 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
safe_fusing (`bool`, defaults to `False`):
Whether to check fused weights for NaN values before fusing.
adapter_names (`List[str]`, *optional*):
Adapter names to be used for fusing.
Example:
```py
from diffusers import Kandinsky5T2VPipeline
pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
pipeline.load_lora_weights("path/to/lora.safetensors")
pipeline.fuse_lora(lora_scale=0.7)
```
"""
super().fuse_lora(
components=components,
@@ -3810,10 +3916,12 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
Reverses the effect of [`pipe.fuse_lora()`].
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
"""
super().unfuse_lora(components=components, **kwargs)
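Taken together, the docstrings above describe the usual pipeline-level LoRA workflow. A minimal sketch, reusing the checkpoint id and LoRA path from the `fuse_lora` example above (both are placeholders, and the dtype/device choices are assumptions):
```py
import torch
from diffusers import Kandinsky5T2VPipeline

# Checkpoint id and LoRA path are placeholders taken from the fuse_lora docstring above.
pipeline = Kandinsky5T2VPipeline.from_pretrained(
    "ai-forever/Kandinsky-5.0-T2V", torch_dtype=torch.bfloat16
).to("cuda")

# Load the adapter, optionally fuse it into the base weights for inference,
# then unfuse to restore the original parameters.
pipeline.load_lora_weights("path/to/lora.safetensors", adapter_name="my_adapter")
pipeline.fuse_lora(lora_scale=0.7)
# ... run inference ...
pipeline.unfuse_lora()
```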

View File

@@ -52,10 +52,6 @@ else:
"FluxKontextAutoBlocks",
"FluxKontextModularPipeline",
]
_import_structure["flux2"] = [
"Flux2AutoBlocks",
"Flux2ModularPipeline",
]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
"QwenImageModularPipeline",
@@ -79,7 +75,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .components_manager import ComponentsManager
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,

View File

@@ -1,111 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["encoders"] = [
"Flux2TextEncoderStep",
"Flux2RemoteTextEncoderStep",
"Flux2VaeEncoderStep",
]
_import_structure["before_denoise"] = [
"Flux2SetTimestepsStep",
"Flux2PrepareLatentsStep",
"Flux2RoPEInputsStep",
"Flux2PrepareImageLatentsStep",
]
_import_structure["denoise"] = [
"Flux2LoopDenoiser",
"Flux2LoopAfterDenoiser",
"Flux2DenoiseLoopWrapper",
"Flux2DenoiseStep",
]
_import_structure["decoders"] = ["Flux2DecodeStep"]
_import_structure["inputs"] = [
"Flux2ProcessImagesInputStep",
"Flux2TextInputStep",
]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"REMOTE_AUTO_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"IMAGE_CONDITIONED_BLOCKS",
"Flux2AutoBlocks",
"Flux2AutoVaeEncoderStep",
"Flux2BeforeDenoiseStep",
"Flux2VaeEncoderSequentialStep",
]
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .before_denoise import (
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .denoise import (
Flux2DenoiseLoopWrapper,
Flux2DenoiseStep,
Flux2LoopAfterDenoiser,
Flux2LoopDenoiser,
)
from .encoders import (
Flux2RemoteTextEncoderStep,
Flux2TextEncoderStep,
Flux2VaeEncoderStep,
)
from .inputs import (
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
)
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
IMAGE_CONDITIONED_BLOCKS,
REMOTE_AUTO_BLOCKS,
TEXT2IMAGE_BLOCKS,
Flux2AutoBlocks,
Flux2AutoVaeEncoderStep,
Flux2BeforeDenoiseStep,
Flux2VaeEncoderSequentialStep,
)
from .modular_pipeline import Flux2ModularPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -1,508 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Union
import numpy as np
import torch
from ...models import Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
"""Compute empirical mu for Flux2 timestep scheduling."""
a1, b1 = 8.73809524e-05, 1.89833333
a2, b2 = 0.00016927, 0.45666666
if image_seq_len > 4300:
mu = a2 * image_seq_len + b2
return float(mu)
m_200 = a2 * image_seq_len + b2
m_10 = a1 * image_seq_len + b1
a = (m_200 - m_10) / 190.0
b = m_200 - 200.0 * a
mu = a * num_steps + b
return float(mu)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class Flux2SetTimestepsStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
ComponentSpec("transformer", Flux2Transformer2DModel),
]
@property
def description(self) -> str:
return "Step that sets the scheduler's timesteps for Flux2 inference using empirical mu calculation"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_inference_steps", default=50),
InputParam("timesteps"),
InputParam("sigmas"),
InputParam("guidance_scale", default=4.0),
InputParam("latents", type_hint=torch.Tensor),
InputParam("num_images_per_prompt", default=1),
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
OutputParam(
"num_inference_steps",
type_hint=int,
description="The number of denoising steps to perform at inference time",
),
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.device = components._execution_device
scheduler = components.scheduler
height = block_state.height or components.default_height
width = block_state.width or components.default_width
vae_scale_factor = components.vae_scale_factor
latent_height = 2 * (int(height) // (vae_scale_factor * 2))
latent_width = 2 * (int(width) // (vae_scale_factor * 2))
image_seq_len = (latent_height // 2) * (latent_width // 2)
num_inference_steps = block_state.num_inference_steps
sigmas = block_state.sigmas
timesteps = block_state.timesteps
if timesteps is None and sigmas is None:
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
sigmas = None
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
timesteps, num_inference_steps = retrieve_timesteps(
scheduler,
num_inference_steps,
block_state.device,
timesteps=timesteps,
sigmas=sigmas,
mu=mu,
)
block_state.timesteps = timesteps
block_state.num_inference_steps = num_inference_steps
batch_size = block_state.batch_size * block_state.num_images_per_prompt
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
guidance = guidance.expand(batch_size)
block_state.guidance = guidance
components.scheduler.set_begin_index(0)
self.set_block_state(state, block_state)
return components, state
class Flux2PrepareLatentsStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def description(self) -> str:
return "Prepare latents step that prepares the initial noise latents for Flux2 text-to-image generation"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("height", type_hint=int),
InputParam("width", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_images_per_prompt", type_hint=int, default=1),
InputParam("generator"),
InputParam(
"batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
),
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
),
OutputParam("latent_ids", type_hint=torch.Tensor, description="Position IDs for the latents (for RoPE)"),
]
@staticmethod
def check_inputs(components, block_state):
vae_scale_factor = components.vae_scale_factor
if (block_state.height is not None and block_state.height % (vae_scale_factor * 2) != 0) or (
block_state.width is not None and block_state.width % (vae_scale_factor * 2) != 0
):
logger.warning(
f"`height` and `width` have to be divisible by {vae_scale_factor * 2} but are {block_state.height} and {block_state.width}."
)
@staticmethod
def _prepare_latent_ids(latents: torch.Tensor):
"""
Generates 4D position coordinates (T, H, W, L) for latent tensors.
Args:
latents: Latent tensor of shape (B, C, H, W)
Returns:
Position IDs tensor of shape (B, H*W, 4)
"""
batch_size, _, height, width = latents.shape
t = torch.arange(1)
h = torch.arange(height)
w = torch.arange(width)
l = torch.arange(1)
latent_ids = torch.cartesian_prod(t, h, w, l)
latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
return latent_ids
@staticmethod
def _pack_latents(latents):
"""Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)"""
batch_size, num_channels, height, width = latents.shape
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
return latents
@staticmethod
def prepare_latents(
comp,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = 2 * (int(height) // (comp.vae_scale_factor * 2))
width = 2 * (int(width) // (comp.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents * 4, height // 2, width // 2)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
return latents
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
block_state.device = components._execution_device
block_state.num_channels_latents = components.num_channels_latents
self.check_inputs(components, block_state)
batch_size = block_state.batch_size * block_state.num_images_per_prompt
latents = self.prepare_latents(
components,
batch_size,
block_state.num_channels_latents,
block_state.height,
block_state.width,
block_state.dtype,
block_state.device,
block_state.generator,
block_state.latents,
)
latent_ids = self._prepare_latent_ids(latents)
latent_ids = latent_ids.to(block_state.device)
latents = self._pack_latents(latents)
block_state.latents = latents
block_state.latent_ids = latent_ids
self.set_block_state(state, block_state)
return components, state
class Flux2RoPEInputsStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return "Step that prepares the 4D RoPE position IDs for Flux2 denoising. Should be placed after text encoder and latent preparation steps."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="prompt_embeds", required=True),
InputParam(name="latent_ids"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="txt_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
),
OutputParam(
name="latent_ids",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
),
]
@staticmethod
def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
"""Prepare 4D position IDs for text tokens."""
B, L, _ = x.shape
out_ids = []
for i in range(B):
t = torch.arange(1) if t_coord is None else t_coord[i]
h = torch.arange(1)
w = torch.arange(1)
seq_l = torch.arange(L)
coords = torch.cartesian_prod(t, h, w, seq_l)
out_ids.append(coords)
return torch.stack(out_ids)
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
prompt_embeds = block_state.prompt_embeds
device = prompt_embeds.device
block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
block_state.txt_ids = block_state.txt_ids.to(device)
self.set_block_state(state, block_state)
return components, state
class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return "Step that prepares image latents and their position IDs for Flux2 image conditioning."
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image_latents", type_hint=List[torch.Tensor]),
InputParam("batch_size", required=True, type_hint=int),
InputParam("num_images_per_prompt", default=1, type_hint=int),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_latents",
type_hint=torch.Tensor,
description="Packed image latents for conditioning",
),
OutputParam(
"image_latent_ids",
type_hint=torch.Tensor,
description="Position IDs for image latents",
),
]
@staticmethod
def _prepare_image_ids(image_latents: List[torch.Tensor], scale: int = 10):
"""
Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
Args:
image_latents: A list of image latent feature tensors of shape (1, C, H, W).
scale: Factor used to define the time separation between latents.
Returns:
Combined coordinate tensor of shape (1, N_total, 4)
"""
if not isinstance(image_latents, list):
raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
t_coords = [t.view(-1) for t in t_coords]
image_latent_ids = []
for x, t in zip(image_latents, t_coords):
x = x.squeeze(0)
_, height, width = x.shape
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
image_latent_ids.append(x_ids)
image_latent_ids = torch.cat(image_latent_ids, dim=0)
image_latent_ids = image_latent_ids.unsqueeze(0)
return image_latent_ids
@staticmethod
def _pack_latents(latents):
"""Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)"""
batch_size, num_channels, height, width = latents.shape
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
return latents
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
image_latents = block_state.image_latents
if image_latents is None:
block_state.image_latents = None
block_state.image_latent_ids = None
self.set_block_state(state, block_state)
return components, state
device = components._execution_device
batch_size = block_state.batch_size * block_state.num_images_per_prompt
image_latent_ids = self._prepare_image_ids(image_latents)
packed_latents = []
for latent in image_latents:
packed = self._pack_latents(latent)
packed = packed.squeeze(0)
packed_latents.append(packed)
image_latents = torch.cat(packed_latents, dim=0)
image_latents = image_latents.unsqueeze(0)
image_latents = image_latents.repeat(batch_size, 1, 1)
image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
image_latent_ids = image_latent_ids.to(device)
block_state.image_latents = image_latents
block_state.image_latent_ids = image_latent_ids
self.set_block_state(state, block_state)
return components, state

View File

@@ -1,146 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple, Union
import numpy as np
import PIL
import torch
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKLFlux2
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Flux2DecodeStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLFlux2),
ComponentSpec(
"image_processor",
Flux2ImageProcessor,
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
default_creation_method="from_config",
),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The denoised latents from the denoising step",
),
InputParam(
"latent_ids",
required=True,
type_hint=torch.Tensor,
description="Position IDs for the latents, used for unpacking",
),
]
@property
def intermediate_outputs(self) -> List[str]:
return [
OutputParam(
"images",
type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
)
]
@staticmethod
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor:
"""
Unpack latents using position IDs to scatter tokens into place.
Args:
x: Packed latents tensor of shape (B, seq_len, C)
x_ids: Position IDs tensor of shape (B, seq_len, 4) with (T, H, W, L) coordinates
Returns:
Unpacked latents tensor of shape (B, C, H, W)
"""
x_list = []
for data, pos in zip(x, x_ids):
_, ch = data.shape # noqa: F841
h_ids = pos[:, 1].to(torch.int64)
w_ids = pos[:, 2].to(torch.int64)
h = torch.max(h_ids) + 1
w = torch.max(w_ids) + 1
flat_ids = h_ids * w + w_ids
out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
out = out.view(h, w, ch).permute(2, 0, 1)
x_list.append(out)
return torch.stack(x_list, dim=0)
@staticmethod
def _unpatchify_latents(latents):
"""Convert patchified latents back to regular format."""
batch_size, num_channels_latents, height, width = latents.shape
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
latents = latents.permute(0, 1, 4, 2, 5, 3)
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
return latents
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
vae = components.vae
if block_state.output_type == "latent":
block_state.images = block_state.latents
else:
latents = block_state.latents
latent_ids = block_state.latent_ids
latents = self._unpack_latents_with_ids(latents, latent_ids)
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
latents.device, latents.dtype
)
latents = latents * latents_bn_std + latents_bn_mean
latents = self._unpatchify_latents(latents)
block_state.images = vae.decode(latents, return_dict=False)[0]
block_state.images = components.image_processor.postprocess(
block_state.images, output_type=block_state.output_type
)
self.set_block_state(state, block_state)
return components, state

View File

@@ -1,252 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Tuple
import torch
from ...models import Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Flux2LoopDenoiser(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("transformer", Flux2Transformer2DModel)]
@property
def description(self) -> str:
return (
"Step within the denoising loop that denoises the latents for Flux2. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `Flux2DenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The latents to denoise. Shape: (B, seq_len, C)",
),
InputParam(
"image_latents",
type_hint=torch.Tensor,
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
),
InputParam(
"image_latent_ids",
type_hint=torch.Tensor,
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
),
InputParam(
"guidance",
required=True,
type_hint=torch.Tensor,
description="Guidance scale as a tensor",
),
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Text embeddings from Mistral3",
),
InputParam(
"txt_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for text tokens (T, H, W, L)",
),
InputParam(
"latent_ids",
required=True,
type_hint=torch.Tensor,
description="4D position IDs for latent tokens (T, H, W, L)",
),
]
@torch.no_grad()
def __call__(
self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
latents = block_state.latents
latent_model_input = latents.to(components.transformer.dtype)
img_ids = block_state.latent_ids
image_latents = getattr(block_state, "image_latents", None)
if image_latents is not None:
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
image_latent_ids = block_state.image_latent_ids
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = components.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=block_state.guidance,
encoder_hidden_states=block_state.prompt_embeds,
txt_ids=block_state.txt_ids,
img_ids=img_ids,
joint_attention_kwargs=block_state.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred[:, : latents.size(1)]
block_state.noise_pred = noise_pred
return components, block_state
class Flux2LoopAfterDenoiser(ModularPipelineBlocks):
model_name = "flux2"
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
@property
def description(self) -> str:
return (
"Step within the denoising loop that updates the latents after denoising. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `Flux2DenoiseLoopWrapper`)"
)
@property
def inputs(self) -> List[Tuple[str, Any]]:
return []
@property
def intermediate_inputs(self) -> List[str]:
return [InputParam("generator")]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
latents_dtype = block_state.latents.dtype
block_state.latents = components.scheduler.step(
block_state.noise_pred,
t,
block_state.latents,
return_dict=False,
)[0]
if block_state.latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
block_state.latents = block_state.latents.to(latents_dtype)
return components, block_state
class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return (
"Pipeline block that iteratively denoises the latents over `timesteps`. "
"The specific steps within each iteration can be customized with `sub_blocks` attribute"
)
@property
def loop_expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
ComponentSpec("transformer", Flux2Transformer2DModel),
]
@property
def loop_inputs(self) -> List[InputParam]:
return [
InputParam(
"timesteps",
required=True,
type_hint=torch.Tensor,
description="The timesteps to use for the denoising process.",
),
InputParam(
"num_inference_steps",
required=True,
type_hint=int,
description="The number of inference steps to use for the denoising process.",
),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.num_warmup_steps = max(
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
)
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)
if i == len(block_state.timesteps) - 1 or (
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self.set_block_state(state, block_state)
return components, state
class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
block_classes = [Flux2LoopDenoiser, Flux2LoopAfterDenoiser]
block_names = ["denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoises the latents for Flux2. \n"
"Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `Flux2LoopDenoiser`\n"
" - `Flux2LoopAfterDenoiser`\n"
"This block supports both text-to-image and image-conditioned generation."
)

View File

@@ -1,349 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
from ...models import AutoencoderKLFlux2
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def format_text_input(prompts: List[str], system_message: str = None):
"""Format prompts for Mistral3 chat template."""
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
return [
[
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
{"role": "user", "content": [{"type": "text", "text": prompt}]},
]
for prompt in cleaned_txt
]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class Flux2TextEncoderStep(ModularPipelineBlocks):
model_name = "flux2"
# fmt: off
DEFAULT_SYSTEM_MESSAGE = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
# fmt: on
@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings using Mistral3 to guide the image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", Mistral3ForConditionalGeneration),
ComponentSpec("tokenizer", AutoProcessor),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
InputParam("joint_attention_kwargs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Text embeddings from Mistral3 used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
prompt_embeds = getattr(block_state, "prompt_embeds", None)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
"Please make sure to only forward one of the two."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@staticmethod
def _get_mistral_3_prompt_embeds(
text_encoder: Mistral3ForConditionalGeneration,
tokenizer: AutoProcessor,
prompt: Union[str, List[str]],
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
max_sequence_length: int = 512,
# fmt: off
system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
# fmt: on
hidden_states_layers: Tuple[int] = (10, 20, 30),
):
dtype = text_encoder.dtype if dtype is None else dtype
device = text_encoder.device if device is None else device
prompt = [prompt] if isinstance(prompt, str) else prompt
messages_batch = format_text_input(prompts=prompt, system_message=system_message)
inputs = tokenizer.apply_chat_template(
messages_batch,
add_generation_prompt=False,
tokenize=True,
return_dict=True,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_sequence_length,
)
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)
output = text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
use_cache=False,
)
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
out = out.to(dtype=dtype, device=device)
batch_size, num_channels, seq_len, hidden_dim = out.shape
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
return prompt_embeds
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.device = components._execution_device
if block_state.prompt_embeds is not None:
self.set_block_state(state, block_state)
return components, state
prompt = block_state.prompt
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
block_state.prompt_embeds = self._get_mistral_3_prompt_embeds(
text_encoder=components.text_encoder,
tokenizer=components.tokenizer,
prompt=prompt,
device=block_state.device,
max_sequence_length=block_state.max_sequence_length,
system_message=self.DEFAULT_SYSTEM_MESSAGE,
hidden_states_layers=block_state.text_encoder_out_layers,
)
self.set_block_state(state, block_state)
return components, state
class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
model_name = "flux2"
REMOTE_URL = "https://remote-text-encoder-flux-2.huggingface.co/predict"
@property
def description(self) -> str:
return "Text Encoder step that generates text embeddings using a remote API endpoint"
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt_embeds",
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Text embeddings from remote API used to guide the image generation",
),
]
@staticmethod
def check_inputs(block_state):
prompt = block_state.prompt
prompt_embeds = getattr(block_state, "prompt_embeds", None)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
"Please make sure to only forward one of the two."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
import io
import requests
from huggingface_hub import get_token
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.device = components._execution_device
if block_state.prompt_embeds is not None:
self.set_block_state(state, block_state)
return components, state
prompt = block_state.prompt
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt
response = requests.post(
self.REMOTE_URL,
json={"prompt": prompt},
headers={
"Authorization": f"Bearer {get_token()}",
"Content-Type": "application/json",
},
)
response.raise_for_status()
block_state.prompt_embeds = torch.load(io.BytesIO(response.content), weights_only=True)
block_state.prompt_embeds = block_state.prompt_embeds.to(block_state.device)
self.set_block_state(state, block_state)
return components, state
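As an aside, the `torch.load(io.BytesIO(response.content), weights_only=True)` call above implies a simple contract for whatever service sits behind `REMOTE_URL`: the response body must be the raw bytes of a torch-serialized tensor. A minimal, hypothetical sketch of the matching server-side serialization (the endpoint's actual implementation is not part of this diff):

```py
import io

import torch


def serialize_prompt_embeds(prompt_embeds: torch.Tensor) -> bytes:
    # The client deserializes with torch.load(io.BytesIO(response.content), weights_only=True),
    # so returning torch.save's raw bytes round-trips cleanly.
    buffer = io.BytesIO()
    torch.save(prompt_embeds, buffer)
    return buffer.getvalue()
```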
class Flux2VaeEncoderStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return "VAE Encoder step that encodes preprocessed images into latent representations for Flux2."
@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("vae", AutoencoderKLFlux2)]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("condition_images", type_hint=List[torch.Tensor]),
InputParam("generator"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"image_latents",
type_hint=List[torch.Tensor],
description="List of latent representations for each reference image",
),
]
@staticmethod
def _patchify_latents(latents):
"""Convert latents to patchified format for Flux2."""
batch_size, num_channels_latents, height, width = latents.shape
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 1, 3, 5, 2, 4)
latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
return latents
def _encode_vae_image(self, vae: AutoencoderKLFlux2, image: torch.Tensor, generator: torch.Generator):
"""Encode a single image using Flux2 VAE with batch norm normalization."""
if image.ndim != 4:
raise ValueError(f"Expected image dims 4, got {image.ndim}.")
image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode="argmax")
image_latents = self._patchify_latents(image_latents)
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps)
latents_bn_std = latents_bn_std.to(image_latents.device, image_latents.dtype)
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
return image_latents
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
condition_images = block_state.condition_images
if condition_images is None:
return components, state
device = components._execution_device
dtype = components.vae.dtype
image_latents = []
for image in condition_images:
image = image.to(device=device, dtype=dtype)
latent = self._encode_vae_image(
vae=components.vae,
image=image,
generator=block_state.generator,
)
image_latents.append(latent)
block_state.image_latents = image_latents
self.set_block_state(state, block_state)
return components, state
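To make the `_patchify_latents` rearrangement above concrete, here is a standalone sketch (not part of the diff) showing the same 2x2 space-to-channel transform on a toy tensor: `(B, C, H, W)` becomes `(B, 4C, H/2, W/2)`, with each 2x2 spatial patch folded into the channel dimension.

```py
import torch

# Toy latents: batch 1, 2 channels, 4x4 spatial grid.
latents = torch.arange(1 * 2 * 4 * 4, dtype=torch.float32).reshape(1, 2, 4, 4)

b, c, h, w = latents.shape
patched = latents.view(b, c, h // 2, 2, w // 2, 2)   # split H and W into 2x2 blocks
patched = patched.permute(0, 1, 3, 5, 2, 4)          # move the 2x2 block dims next to channels
patched = patched.reshape(b, c * 4, h // 2, w // 2)  # fold each block into the channel dim

print(latents.shape, "->", patched.shape)  # torch.Size([1, 2, 4, 4]) -> torch.Size([1, 8, 2, 2])
```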

View File

@@ -1,160 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import torch
from ...configuration_utils import FrozenDict
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline
logger = logging.get_logger(__name__)
class Flux2TextInputStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return (
"This step:\n"
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
" 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_images_per_prompt", default=1),
InputParam(
"prompt_embeds",
required=True,
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"batch_size",
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
),
OutputParam(
"dtype",
type_hint=torch.dtype,
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
),
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="Text embeddings used to guide the image generation",
),
]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.batch_size = block_state.prompt_embeds.shape[0]
block_state.dtype = block_state.prompt_embeds.dtype
_, seq_len, _ = block_state.prompt_embeds.shape
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
block_state.prompt_embeds = block_state.prompt_embeds.view(
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
)
self.set_block_state(state, block_state)
return components, state
class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
model_name = "flux2"
@property
def description(self) -> str:
return "Image preprocess step for Flux2. Validates and preprocesses reference images."
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec(
"image_processor",
Flux2ImageProcessor,
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image"),
InputParam("height"),
InputParam("width"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [OutputParam(name="condition_images", type_hint=List[torch.Tensor])]
@torch.no_grad()
def __call__(self, components: Flux2ModularPipeline, state: PipelineState):
block_state = self.get_block_state(state)
images = block_state.image
if images is None:
block_state.condition_images = None
self.set_block_state(state, block_state)
return components, state
if not isinstance(images, list):
images = [images]
condition_images = []
for img in images:
components.image_processor.check_image_input(img)
image_width, image_height = img.size
if image_width * image_height > 1024 * 1024:
img = components.image_processor._resize_to_target_area(img, 1024 * 1024)
image_width, image_height = img.size
multiple_of = components.vae_scale_factor * 2
image_width = (image_width // multiple_of) * multiple_of
image_height = (image_height // multiple_of) * multiple_of
condition_img = components.image_processor.preprocess(
img, height=image_height, width=image_width, resize_mode="crop"
)
condition_images.append(condition_img)
if block_state.height is None:
block_state.height = image_height
if block_state.width is None:
block_state.width = image_width
block_state.condition_images = condition_images
self.set_block_state(state, block_state)
return components, state
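The sizing logic in `Flux2ProcessImagesInputStep.__call__` caps reference images at roughly a 1024x1024 pixel area and then floors both dimensions to a multiple of `vae_scale_factor * 2`. A small worked sketch, assuming the pipeline's default `vae_scale_factor` of 8 (so the multiple is 16):

```py
# Standalone illustration of the rounding applied to reference image sizes.
vae_scale_factor = 8             # Flux2ModularPipeline default
multiple_of = vae_scale_factor * 2

image_width, image_height = 1023, 771
image_width = (image_width // multiple_of) * multiple_of    # 1008
image_height = (image_height // multiple_of) * multiple_of  # 768
```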

View File

@@ -1,166 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .denoise import Flux2DenoiseStep
from .encoders import (
Flux2RemoteTextEncoderStep,
Flux2TextEncoderStep,
Flux2VaeEncoderStep,
)
from .inputs import (
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
Flux2VaeEncoderBlocks = InsertableDict(
[
("preprocess", Flux2ProcessImagesInputStep()),
("encode", Flux2VaeEncoderStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
]
)
class Flux2VaeEncoderSequentialStep(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = Flux2VaeEncoderBlocks.values()
block_names = Flux2VaeEncoderBlocks.keys()
@property
def description(self) -> str:
return "VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning."
class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [Flux2VaeEncoderSequentialStep]
block_names = ["img_conditioning"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"VAE encoder step that encodes the image inputs into their latent representations.\n"
"This is an auto pipeline block that works for image conditioning tasks.\n"
" - `Flux2VaeEncoderSequentialStep` is used when `image` is provided.\n"
" - If `image` is not provided, step will be skipped."
)
Flux2BeforeDenoiseBlocks = InsertableDict(
[
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
]
)
class Flux2BeforeDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = Flux2BeforeDenoiseBlocks.values()
block_names = Flux2BeforeDenoiseBlocks.keys()
@property
def description(self):
return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation."
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", Flux2TextEncoderStep()),
("text_input", Flux2TextInputStep()),
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
("before_denoise", Flux2BeforeDenoiseStep()),
("denoise", Flux2DenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
REMOTE_AUTO_BLOCKS = InsertableDict(
[
("text_encoder", Flux2RemoteTextEncoderStep()),
("text_input", Flux2TextInputStep()),
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
("before_denoise", Flux2BeforeDenoiseStep()),
("denoise", Flux2DenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
class Flux2AutoBlocks(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = AUTO_BLOCKS.values()
block_names = AUTO_BLOCKS.keys()
@property
def description(self):
return (
"Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2.\n"
"- For text-to-image generation, all you need to provide is `prompt`.\n"
"- For image-conditioned generation, you need to provide `image` (list of PIL images)."
)
TEXT2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", Flux2TextEncoderStep()),
("text_input", Flux2TextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
IMAGE_CONDITIONED_BLOCKS = InsertableDict(
[
("text_encoder", Flux2TextEncoderStep()),
("text_input", Flux2TextInputStep()),
("preprocess_images", Flux2ProcessImagesInputStep()),
("vae_encoder", Flux2VaeEncoderStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
ALL_BLOCKS = {
"text2image": TEXT2IMAGE_BLOCKS,
"image_conditioned": IMAGE_CONDITIONED_BLOCKS,
"auto": AUTO_BLOCKS,
"remote": REMOTE_AUTO_BLOCKS,
}
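For context, these presets are typically consumed by assembling them into a `SequentialPipelineBlocks` and initializing a modular pipeline from a repository. A hypothetical usage sketch, assuming the same helpers exposed by other diffusers modular pipelines (`from_blocks_dict`, `init_pipeline`, `load_components`; the exact call signatures and the repo id below are assumptions, not confirmed by this diff):

```py
import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks

# Build a runnable pipeline from the text-to-image preset defined above.
blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
pipeline = blocks.init_pipeline("black-forest-labs/FLUX.2-dev")  # placeholder repo id
pipeline.load_components(torch_dtype=torch.bfloat16)

image = pipeline(prompt="a tiny astronaut hatching from an egg on the moon", output="images")[0]
```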

View File

@@ -1,57 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
"""
A ModularPipeline for Flux2.
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "Flux2AutoBlocks"
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_width(self):
return self.default_sample_size * self.vae_scale_factor
@property
def default_sample_size(self):
return 128
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if getattr(self, "vae", None) is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_latents(self):
num_channels_latents = 32
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
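For orientation, the defaults above work out as follows; a worked sketch where the transformer `in_channels` value is an assumption chosen to be consistent with the fallback of 32:

```py
# Assuming a VAE with four block_out_channels entries: vae_scale_factor = 2 ** (4 - 1) = 8.
default_sample_size = 128
vae_scale_factor = 8
default_height = default_width = default_sample_size * vae_scale_factor  # 1024 px

# num_channels_latents falls back to 32; a transformer with in_channels = 128
# (hypothetical value) yields the same result via in_channels // 4.
num_channels_latents = 128 // 4  # 32
```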

View File

@@ -58,7 +58,6 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
("wan", "WanModularPipeline"),
("flux", "FluxModularPipeline"),
("flux-kontext", "FluxKontextModularPipeline"),
("flux2", "Flux2ModularPipeline"),
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
@@ -1587,6 +1586,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
for name, config_spec in self._config_specs.items():
default_configs[name] = config_spec.default
self.register_to_config(**default_configs)
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
@property

View File

@@ -2,36 +2,6 @@
from ..utils import DummyObject, requires_backends
class Flux2AutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Flux2ModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class FluxAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -13,16 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import pytest
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers import (
AuraFlowPipeline,
AuraFlowTransformer2DModel,
FlowMatchEulerDiscreteScheduler,
)
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, FlowMatchEulerDiscreteScheduler
from ..testing_utils import (
floats_tensor,
@@ -40,7 +36,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestAuraFlowLoRA(PeftLoraLoaderMixinTests):
pipeline_class = AuraFlowPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -103,34 +99,34 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in AuraFlow.")
@pytest.mark.skip("Not supported in AuraFlow.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in AuraFlow.")
@pytest.mark.skip("Not supported in AuraFlow.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in AuraFlow.")
@pytest.mark.skip("Not supported in AuraFlow.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -13,10 +13,9 @@
# limitations under the License.
import sys
import unittest
import pytest
import torch
from parameterized import parameterized
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
@@ -39,7 +38,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestCogVideoXLoRA(PeftLoraLoaderMixinTests):
pipeline_class = CogVideoXPipeline
scheduler_cls = CogVideoXDPMScheduler
scheduler_kwargs = {"timestep_spacing": "trailing"}
@@ -119,54 +118,59 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3, pipe=pipe)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
def test_lora_scale_kwargs_match_fusion(self):
super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
def test_lora_scale_kwargs_match_fusion(self, base_pipe_output):
super().test_lora_scale_kwargs_match_fusion(
base_pipe_output=base_pipe_output, expected_atol=9e-3, expected_rtol=9e-3
)
@parameterized.expand([("block_level", True), ("leaf_level", False)])
@pytest.mark.parametrize(
"offload_type, use_stream",
[("block_level", True), ("leaf_level", False)],
)
@require_torch_accelerator
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe):
# TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe)
@unittest.skip("Not supported in CogVideoX.")
@pytest.mark.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in CogVideoX.")
@pytest.mark.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in CogVideoX.")
@pytest.mark.skip("Not supported in CogVideoX.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@unittest.skip("Not supported in CogVideoX.")
@pytest.mark.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
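The `pipe`, `tmpdirname`, and `base_pipe_output` arguments these migrated tests now accept are pytest fixtures provided by the shared mixin. A minimal sketch of what such fixtures could look like (the real definitions live in `tests/lora/utils.py`, which this excerpt does not show; names and bodies here are assumptions):

```py
import tempfile

import pytest
import torch


class PeftLoraLoaderMixinTests:
    @pytest.fixture
    def tmpdirname(self):
        # Replaces the per-test `with tempfile.TemporaryDirectory()` blocks.
        with tempfile.TemporaryDirectory() as tmpdir:
            yield tmpdir

    @pytest.fixture
    def pipe(self):
        # Replaces the repeated pipeline construction boilerplate.
        components, _, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
        return pipe

    @pytest.fixture
    def base_pipe_output(self, pipe):
        # Baseline output with no LoRA loaded, shared across assertions.
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        return pipe(**inputs, generator=torch.manual_seed(0))[0]
```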

View File

@@ -13,12 +13,9 @@
# limitations under the License.
import sys
import tempfile
import unittest
import numpy as np
import pytest
import torch
from parameterized import parameterized
from transformers import AutoTokenizer, GlmModel
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
@@ -28,7 +25,6 @@ from ..testing_utils import (
require_peft_backend,
require_torch_accelerator,
skip_mps,
torch_device,
)
@@ -47,7 +43,7 @@ class TokenizerWrapper:
@require_peft_backend
@skip_mps
class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestCogView4LoRA(PeftLoraLoaderMixinTests):
pipeline_class = CogView4Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -113,72 +109,50 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_save_pretrained(self):
"""
Tests a simple use case where users could use saving utilities for LoRA through save_pretrained
"""
components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
pipe_from_pretrained.to(torch_device)
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
@parameterized.expand([("block_level", True), ("leaf_level", False)])
@pytest.mark.parametrize(
"offload_type, use_stream",
[("block_level", True), ("leaf_level", False)],
)
@require_torch_accelerator
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe):
# TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe)
@unittest.skip("Not supported in CogView4.")
@pytest.mark.skip("Not supported in CogView4.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in CogView4.")
@pytest.mark.skip("Not supported in CogView4.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in CogView4.")
@pytest.mark.skip("Not supported in CogView4.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
def test_simple_inference_with_text_lora_save_load(self):
pass

View File

@@ -16,13 +16,11 @@ import copy
import gc
import os
import sys
import tempfile
import unittest
import numpy as np
import pytest
import safetensors.torch
import torch
from parameterized import parameterized
from PIL import Image
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
@@ -46,14 +44,12 @@ from ..testing_utils import (
if is_peft_available():
from peft.utils import get_peft_model_state_dict
sys.path.append(".")
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set
@require_peft_backend
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestFluxLoRA(PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -115,165 +111,134 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_with_alpha_in_state_dict(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
def test_with_alpha_in_state_dict(self, tmpdirname, pipe):
_, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.transformer.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
with tempfile.TemporaryDirectory() as tmpdirname:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
# modify the state dict to have alpha values following
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
state_dict_with_alpha = safetensors.torch.load_file(
os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
)
alpha_dict = {}
for k, v in state_dict_with_alpha.items():
# only do for `transformer` and for the k projections -- should be enough to test.
if "transformer" in k and "to_k" in k and "lora_A" in k:
alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
state_dict_with_alpha.update(alpha_dict)
# modify the state dict to have alpha values following
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
state_dict_with_alpha = safetensors.torch.load_file(
os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
)
alpha_dict = {}
for k, v in state_dict_with_alpha.items():
if "transformer" in k and "to_k" in k and ("lora_A" in k):
alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
state_dict_with_alpha.update(alpha_dict)
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
pipe.unload_lora_weights()
pipe.load_lora_weights(state_dict_with_alpha)
images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), (
"Loading from saved checkpoints should give same results."
)
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
assert not np.allclose(images_lora_with_alpha, images_lora, atol=0.001, rtol=0.001)
def test_lora_expansion_works_for_absent_keys(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
def test_lora_expansion_works_for_absent_keys(self, base_pipe_output, tmpdirname, pipe):
_, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = self.get_base_pipe_output()
# Modify the config to have a layer which won't be present in the second LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
modified_denoiser_lora_config.target_modules.add("x_embedder")
pipe.transformer.add_adapter(modified_denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
"LoRA should lead to different results.",
assert not np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), (
"LoRA should lead to different results."
)
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
# Modify the state dict to exclude "x_embedder" related LoRA params.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
pipe.set_adapters(["one", "two"])
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
"Different LoRAs should lead to different results.",
assert not np.allclose(images_lora, images_lora_with_absent_keys, atol=0.001, rtol=0.001), (
"Different LoRAs should lead to different results."
)
self.assertFalse(
np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
"LoRA should lead to different results.",
assert not np.allclose(base_pipe_output, images_lora_with_absent_keys, atol=0.001, rtol=0.001), (
"LoRA should lead to different results."
)
def test_lora_expansion_works_for_extra_keys(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname, pipe):
_, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = self.get_base_pipe_output()
# Modify the config to have a layer which won't be present in the first LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
modified_denoiser_lora_config.target_modules.add("x_embedder")
pipe.transformer.add_adapter(modified_denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
"LoRA should lead to different results.",
assert not np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), (
"LoRA should lead to different results."
)
with tempfile.TemporaryDirectory() as tmpdirname:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.unload_lora_weights()
# Modify the state dict to exclude "x_embedder" related LoRA params.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
# Load state dict with `x_embedder`.
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
pipe.unload_lora_weights()
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
pipe.set_adapters(["one", "two"])
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
"Different LoRAs should lead to different results.",
assert not np.allclose(images_lora, images_lora_with_extra_keys, atol=0.001, rtol=0.001), (
"Different LoRAs should lead to different results."
)
self.assertFalse(
np.allclose(output_no_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
"LoRA should lead to different results.",
assert not np.allclose(base_pipe_output, images_lora_with_extra_keys, atol=0.001, rtol=0.001), (
"LoRA should lead to different results."
)
@unittest.skip("Not supported in Flux.")
@pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Flux.")
@pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Flux.")
@pytest.mark.skip("Not supported in Flux.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Not supported in Flux.")
@pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
pipeline_class = FluxControlPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -338,12 +303,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_with_norm_in_state_dict(self):
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
def test_with_norm_in_state_dict(self, pipe):
_, _, inputs = self.get_dummy_inputs(with_generator=False)
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
@@ -364,39 +324,32 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe.load_lora_weights(norm_state_dict)
lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
assert (
"The provided state dict contains normalization layers in addition to LoRA layers"
in cap_logger.out
)
self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0)
assert len(pipe.transformer._transformer_norm_layers) > 0
pipe.unload_lora_weights()
lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(pipe.transformer._transformer_norm_layers is None)
self.assertTrue(np.allclose(original_output, lora_unload_output, atol=1e-5, rtol=1e-5))
self.assertFalse(
np.allclose(original_output, lora_load_output, atol=1e-6, rtol=1e-6), f"{norm_layer} is tested"
assert pipe.transformer._transformer_norm_layers is None
assert np.allclose(original_output, lora_unload_output, atol=1e-05, rtol=1e-05)
assert not np.allclose(original_output, lora_load_output, atol=1e-06, rtol=1e-06), (
f"{norm_layer} is tested"
)
with CaptureLogger(logger) as cap_logger:
for key in list(norm_state_dict.keys()):
norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key)
pipe.load_lora_weights(norm_state_dict)
assert "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
self.assertTrue(
"Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
)
def test_lora_parameter_expanded_shapes(self):
def test_lora_parameter_expanded_shapes(self, pipe):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG)
@@ -405,24 +358,21 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
self.assertTrue(
transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
assert transformer.config.in_channels == num_channels_without_control, (
f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
)
original_transformer_state_dict = pipe.transformer.state_dict()
x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight")
incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False)
self.assertTrue(
"x_embedder.weight" in incompatible_keys.missing_keys,
"Could not find x_embedder.weight in the missing keys.",
assert "x_embedder.weight" in incompatible_keys.missing_keys, (
"Could not find x_embedder.weight in the missing keys."
)
transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control])
pipe.transformer = transformer
out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4
dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
@@ -431,15 +381,13 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
assert pipe.transformer.config.in_channels == 2 * in_features
assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
# Testing opposite direction where the LoRA params are zero-padded.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -454,15 +402,13 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
assert pipe.transformer.config.in_channels == 2 * in_features
assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
def test_normal_lora_with_expanded_lora_raises_error(self):
# Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
@@ -494,32 +440,28 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
_, _, inputs = self.get_dummy_inputs(with_generator=False)
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
assert pipe.get_active_adapters() == ["adapter-1"]
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
assert pipe.transformer.config.in_channels == 2 * in_features
assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
(_, _, inputs) = self.get_dummy_inputs(with_generator=False)
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
self.assertTrue(pipe.get_active_adapters() == ["adapter-2"])
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
assert pipe.get_active_adapters() == ["adapter-2"]
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001)
# Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
# This should raise a runtime error on input shapes being incompatible.
@@ -540,32 +482,24 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
self.assertTrue(pipe.transformer.config.in_channels == in_features)
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features
assert pipe.transformer.config.in_channels == in_features
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
}
# We should check for input shapes being incompatible here. But because the above-mentioned issue is
# not a supported use case, and because of the PEFT renaming, we will currently have a shape
# mismatch error.
self.assertRaisesRegex(
RuntimeError,
"size mismatch for x_embedder.lora_A.adapter-2.weight",
pipe.load_lora_weights,
lora_state_dict,
"adapter-2",
)
with pytest.raises(RuntimeError, match="size mismatch for x_embedder.lora_A.adapter-2.weight"):
pipe.load_lora_weights(lora_state_dict, "adapter-2")
def test_fuse_expanded_lora_with_regular_lora(self):
# This test checks if it works when a lora with expanded shapes (like control loras) but
@@ -597,7 +531,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
}
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
_, _, inputs = self.get_dummy_inputs(with_generator=False)
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -610,54 +544,44 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
pipe.load_lora_weights(lora_state_dict, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0])
lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3))
self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3))
assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001)
assert not np.allclose(lora_output, lora_output_3, atol=0.001, rtol=0.001)
assert not np.allclose(lora_output_2, lora_output_3, atol=0.001, rtol=0.001)
pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"])
lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3))
assert np.allclose(lora_output_3, lora_output_4, atol=0.001, rtol=0.001)
def test_load_regular_lora(self):
def test_load_regular_lora(self, base_pipe_output, pipe):
# This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded
# into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those
# transformers include Flux Fill, Flux Control, etc.
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
out_features, in_features = pipe.transformer.x_embedder.weight.shape
rank = 4
in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA.
in_features = in_features // 2
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
lora_state_dict = {
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
}
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.INFO)
with CaptureLogger(logger) as cap_logger:
pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2
assert not np.allclose(base_pipe_output, lora_output, atol=0.001, rtol=0.001)
def test_lora_unload_with_parameter_expanded_shapes(self):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -670,9 +594,8 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
self.assertTrue(
transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
assert transformer.config.in_channels == num_channels_without_control, (
f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
)
# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
@@ -697,33 +620,31 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
with CaptureLogger(logger) as cap_logger:
control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
inputs["control_image"] = control_image
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
assert pipe.transformer.config.in_channels == 2 * in_features
assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
self.assertTrue(
control_pipe.transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
assert control_pipe.transformer.config.in_channels == num_channels_without_control, (
f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}"
)
loaded_pipe = FluxPipeline.from_pipe(control_pipe)
self.assertTrue(
loaded_pipe.transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
assert loaded_pipe.transformer.config.in_channels == num_channels_without_control, (
f"Expected {num_channels_without_control} channels in the modified transformer but has loaded_pipe.transformer.config.in_channels={loaded_pipe.transformer.config.in_channels!r}"
)
inputs.pop("control_image")
unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
self.assertTrue(pipe.transformer.config.in_channels == in_features)
assert not np.allclose(unloaded_lora_out, lora_out, rtol=0.0001, atol=0.0001)
assert np.allclose(unloaded_lora_out, original_out, atol=0.0001, rtol=0.0001)
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features
assert pipe.transformer.config.in_channels == in_features
def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
@@ -731,14 +652,12 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(logging.DEBUG)
# Change the transformer config to mimic a real use case.
num_channels_without_control = 4
transformer = FluxTransformer2DModel.from_config(
components["transformer"].config, in_channels=num_channels_without_control
).to(torch_device)
self.assertTrue(
transformer.config.in_channels == num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
assert transformer.config.in_channels == num_channels_without_control, (
f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
)
# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
@@ -763,40 +682,38 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
}
with CaptureLogger(logger) as cap_logger:
control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
inputs["control_image"] = control_image
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
assert pipe.transformer.config.in_channels == 2 * in_features
assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
control_pipe.unload_lora_weights(reset_to_overwritten_params=False)
self.assertTrue(
control_pipe.transformer.config.in_channels == 2 * num_channels_without_control,
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
assert control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, (
f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}"
)
no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
assert not np.allclose(no_lora_out, lora_out, rtol=0.0001, atol=0.0001)
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2
assert pipe.transformer.config.in_channels == in_features * 2
self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4))
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)
@unittest.skip("Not supported in Flux.")
@pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Flux.")
@pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Flux.")
@pytest.mark.skip("Not supported in Flux.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Not supported in Flux.")
@pytest.mark.skip("Not supported in Flux.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
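Each `@unittest.skip` above becomes a `@pytest.mark.skip` with the same message. A short, self-contained sketch of the unconditional and conditional variants; the `IS_GITHUB_ACTIONS` flag here is derived from the environment only for illustration (the test suite imports it from its own utilities):

```py
import os

import pytest

# Illustrative stand-in; the real flag comes from the shared testing utilities.
IS_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true"


@pytest.mark.skip("Not supported in this pipeline.")
def test_block_scale_not_supported():
    # Collected but never executed; reported as skipped with the given reason.
    pass


@pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_skipped_on_ci():
    # Conditional skip, the pytest counterpart of unittest.skipIf.
    pass
```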
@@ -806,7 +723,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
class FluxLoRAIntegrationTests(unittest.TestCase):
class TestFluxLoRAIntegration:
"""internal note: The integration slices were obtained on audace.
torch: 2.6.0.dev20241006+cu124 with CUDA 12.5. Need the same setup for the
@@ -816,33 +733,27 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
num_inference_steps = 10
seed = 0
def setUp(self):
super().setUp()
@pytest.fixture(scope="function")
def pipeline(self):
gc.collect()
backend_empty_cache(torch_device)
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(
torch_device
)
try:
yield pipe
finally:
del pipe
gc.collect()
backend_empty_cache(torch_device)
self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
def tearDown(self):
super().tearDown()
del self.pipeline
gc.collect()
backend_empty_cache(torch_device)
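The hunk above replaces `setUp`/`tearDown` with a function-scoped `pipeline` fixture whose `try/finally` guarantees cleanup even when a test fails. A minimal standalone sketch of the same pattern; the hard-coded `"cuda"` and `torch.cuda.empty_cache()` substitute for the suite's `torch_device`/`backend_empty_cache` helpers:

```py
import gc

import pytest
import torch
from diffusers import FluxPipeline


@pytest.fixture(scope="function")
def pipeline():
    # One fresh pipeline per test; everything after yield replaces tearDown.
    gc.collect()
    torch.cuda.empty_cache()
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")
    try:
        yield pipe
    finally:
        del pipe
        gc.collect()
        torch.cuda.empty_cache()


def test_fuse_and_unload(pipeline):
    # The fixture is injected by name; no self.pipeline attribute is needed.
    pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
    pipeline.fuse_lora()
    pipeline.unload_lora_weights()
```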
def test_flux_the_last_ben(self):
self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
# Instead of calling `enable_model_cpu_offload()`, we do an accelerator placement here because the CI
# run supports it. We have about 34GB RAM in the CI runner which kills the test when run with
# `enable_model_cpu_offload()`. We repeat this for the other tests, too.
self.pipeline = self.pipeline.to(torch_device)
def test_flux_the_last_ben(self, pipeline):
pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "jon snow eating pizza with ketchup"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=4.0,
@@ -851,71 +762,57 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
).images
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
assert max_diff < 1e-3
def test_flux_kohya(self):
self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline = self.pipeline.to(torch_device)
def test_flux_kohya(self, pipeline):
pipeline.load_lora_weights("Norod78/brain-slug-flux")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "The cat with a brain slug earring"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=4.5,
output_type="np",
generator=torch.manual_seed(self.seed),
).images
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
assert max_diff < 1e-3
def test_flux_kohya_with_text_encoder(self):
self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline = self.pipeline.to(torch_device)
def test_flux_kohya_with_text_encoder(self, pipeline):
pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "optimus is cleaning the house with broomstick"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=4.5,
output_type="np",
generator=torch.manual_seed(self.seed),
).images
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.4023, 0.4023, 0.4023, 0.3965, 0.3984, 0.3965, 0.3926, 0.3906, 0.4219])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
assert max_diff < 1e-3
def test_flux_kohya_embedders_conversion(self):
def test_flux_kohya_embedders_conversion(self, pipeline):
"""Test that embedders load without throwing errors"""
self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
self.pipeline.unload_lora_weights()
assert True
def test_flux_xlabs(self):
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline = self.pipeline.to(torch_device)
pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
pipeline.unload_lora_weights()
def test_flux_xlabs(self, pipeline):
pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline = pipeline.to(torch_device)
prompt = "A blue jay standing on a large basket of rainbow macarons, disney style"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=3.5,
@@ -923,23 +820,17 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
generator=torch.manual_seed(self.seed),
).images
out_slice = out[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980])
expected_slice = np.array([0.3965, 0.418, 0.4434, 0.4082, 0.4375, 0.459, 0.4141, 0.4375, 0.498])
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 0.001
assert max_diff < 1e-3
def test_flux_xlabs_load_lora_with_single_blocks(self):
self.pipeline.load_lora_weights(
"salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors"
)
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline.enable_model_cpu_offload()
def test_flux_xlabs_load_lora_with_single_blocks(self, pipeline):
pipeline.load_lora_weights("salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline.enable_model_cpu_offload()
prompt = "a wizard mouse playing chess"
out = self.pipeline(
out = pipeline(
prompt,
num_inference_steps=self.num_inference_steps,
guidance_scale=3.5,
@@ -951,40 +842,43 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
[0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625]
)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
assert max_diff < 1e-3
assert max_diff < 0.001
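These integration checks compare a 3x3 corner slice of the generated image against a slice recorded on the reference hardware, using `numpy_cosine_similarity_distance` from the test utilities. Roughly what the check amounts to, with a locally defined cosine-distance helper so the snippet stands alone (the real helper may differ in detail):

```py
import numpy as np


def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity of the flattened arrays.
    a, b = a.flatten(), b.flatten()
    return float(1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# `out` stands in for the (batch, H, W, C) array returned with output_type="np".
out = np.random.RandomState(0).rand(1, 64, 64, 3).astype(np.float32)
out_slice = out[0, -3:, -3:, -1].flatten()

# In the real tests this is a hard-coded slice obtained on the CI hardware.
expected_slice = out_slice.copy()

assert cosine_distance(expected_slice, out_slice) < 1e-3
```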
@nightly
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
class FluxControlLoRAIntegrationTests(unittest.TestCase):
class TestFluxControlLoRAIntegration:
num_inference_steps = 10
seed = 0
prompt = "A robot made of exotic candies and chocolates of different kinds."
def setUp(self):
super().setUp()
@pytest.fixture(scope="function")
def pipeline(self):
gc.collect()
backend_empty_cache(torch_device)
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(
torch_device
)
try:
yield pipe
finally:
del pipe
gc.collect()
backend_empty_cache(torch_device)
self.pipeline = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
def test_lora(self, lora_ckpt_id):
self.pipeline.load_lora_weights(lora_ckpt_id)
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
@pytest.mark.parametrize(
"lora_ckpt_id",
[
"black-forest-labs/FLUX.1-Canny-dev-lora",
"black-forest-labs/FLUX.1-Depth-dev-lora",
],
)
def test_lora(self, pipeline, lora_ckpt_id):
pipeline.load_lora_weights(lora_ckpt_id)
pipeline.fuse_lora()
pipeline.unload_lora_weights()
if "Canny" in lora_ckpt_id:
control_image = load_image(
@@ -995,7 +889,7 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase):
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
)
image = self.pipeline(
image = pipeline(
prompt=self.prompt,
control_image=control_image,
height=1024,
@@ -1016,12 +910,18 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase):
assert max_diff < 1e-3
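The hunk below swaps `@parameterized.expand` for `@pytest.mark.parametrize`, and the test now receives both the parametrized checkpoint id and the `pipeline` fixture. A minimal sketch of the combination, assuming a `pipeline` fixture like the one sketched earlier:

```py
import pytest


@pytest.mark.parametrize(
    "lora_ckpt_id",
    [
        "black-forest-labs/FLUX.1-Canny-dev-lora",
        "black-forest-labs/FLUX.1-Depth-dev-lora",
    ],
)
def test_lora(pipeline, lora_ckpt_id):
    # pytest generates one test per checkpoint id; fixtures and parameters mix freely.
    pipeline.load_lora_weights(lora_ckpt_id)
    pipeline.fuse_lora()
    pipeline.unload_lora_weights()
```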
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
def test_lora_with_turbo(self, lora_ckpt_id):
self.pipeline.load_lora_weights(lora_ckpt_id)
self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
@pytest.mark.parametrize(
"lora_ckpt_id",
[
"black-forest-labs/FLUX.1-Canny-dev-lora",
"black-forest-labs/FLUX.1-Depth-dev-lora",
],
)
def test_lora_with_turbo(self, pipeline, lora_ckpt_id):
pipeline.load_lora_weights(lora_ckpt_id)
pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
pipeline.fuse_lora()
pipeline.unload_lora_weights()
if "Canny" in lora_ckpt_id:
control_image = load_image(


@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
@@ -30,7 +30,7 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
@require_peft_backend
class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestFlux2LoRA(PeftLoraLoaderMixinTests):
pipeline_class = Flux2Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -133,36 +133,36 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all())
assert np.isnan(out).all()
@unittest.skip("Not supported in Flux2.")
@pytest.mark.skip("Not supported in Flux2.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Flux2.")
@pytest.mark.skip("Not supported in Flux2.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Flux2.")
@pytest.mark.skip("Not supported in Flux2.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
def test_simple_inference_with_text_lora_save_load(self):
pass


@@ -14,9 +14,9 @@
import gc
import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
@@ -48,7 +48,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestHunyuanVideoLoRA(PeftLoraLoaderMixinTests):
pipeline_class = HunyuanVideoPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -149,46 +149,41 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
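These overrides only exist to loosen tolerances, but once the mixin test takes `pipe` as a fixture the subclass has to request it and forward it, because pytest injects fixtures only into the collected (outermost) test function while `super()` is a plain Python call. A rough sketch of the mechanics, with `BaseSuite` standing in for `PeftLoraLoaderMixinTests`:

```py
import pytest


class BaseSuite:
    @pytest.fixture
    def pipe(self):
        # In the real mixin this builds the pipeline under test.
        return {"name": "dummy-pipe"}

    def test_fused_inference(self, pipe, expected_atol=1e-3):
        # Shared body; subclasses reuse it with pipeline-specific tolerances.
        assert pipe["name"] == "dummy-pipe"
        assert expected_atol > 0


class TestVideoLoRA(BaseSuite):
    def test_fused_inference(self, pipe):
        # Request the fixture here and pass it through to the shared implementation.
        super().test_fused_inference(pipe=pipe, expected_atol=9e-3)
```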
# TODO(aryan): Fix the following test
@unittest.skip("This test fails with an error I haven't been able to debug yet.")
def test_simple_inference_save_pretrained(self):
pass
@unittest.skip("Not supported in HunyuanVideo.")
@pytest.mark.skip("Not supported in HunyuanVideo.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in HunyuanVideo.")
@pytest.mark.skip("Not supported in HunyuanVideo.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in HunyuanVideo.")
@pytest.mark.skip("Not supported in HunyuanVideo.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@@ -197,7 +192,7 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
class TestHunyuanVideoLoRAIntegration:
"""internal note: The integration slices were obtained on DGX.
torch: 2.5.1+cu124 with CUDA 12.5. Need the same setup for the
@@ -207,9 +202,8 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
num_inference_steps = 10
seed = 0
def setUp(self):
super().setUp()
@pytest.fixture(scope="function")
def pipeline(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -217,27 +211,27 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
self.pipeline = HunyuanVideoPipeline.from_pretrained(
model_id, transformer=transformer, torch_dtype=torch.float16
).to(torch_device)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16).to(
torch_device
)
try:
yield pipe
finally:
del pipe
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
def test_original_format_cseti(self):
self.pipeline.load_lora_weights(
def test_original_format_cseti(self, pipeline):
pipeline.load_lora_weights(
"Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors"
)
self.pipeline.fuse_lora()
self.pipeline.unload_lora_weights()
self.pipeline.vae.enable_tiling()
pipeline.fuse_lora()
pipeline.unload_lora_weights()
pipeline.vae.enable_tiling()
prompt = "CSETIARCANE. A cat walks on the grass, realistic"
out = self.pipeline(
out = pipeline(
prompt=prompt,
height=320,
width=512,


@@ -13,8 +13,8 @@
# limitations under the License.
import sys
import unittest
import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestLTXVideoLoRA(PeftLoraLoaderMixinTests):
pipeline_class = LTXPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -108,40 +108,40 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
@unittest.skip("Not supported in LTXVideo.")
@pytest.mark.skip("Not supported in LTXVideo.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in LTXVideo.")
@pytest.mark.skip("Not supported in LTXVideo.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in LTXVideo.")
@pytest.mark.skip("Not supported in LTXVideo.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
def test_simple_inference_with_text_lora_save_load(self):
pass


@@ -13,7 +13,6 @@
# limitations under the License.
import sys
import unittest
import numpy as np
import pytest
@@ -36,7 +35,7 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
@require_peft_backend
class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestLumina2LoRA(PeftLoraLoaderMixinTests):
pipeline_class = Lumina2Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -101,35 +100,35 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in Lumina2.")
@pytest.mark.skip("Not supported in Lumina2.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Lumina2.")
@pytest.mark.skip("Not supported in Lumina2.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Lumina2.")
@pytest.mark.skip("Not supported in Lumina2.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@@ -139,20 +138,17 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=False,
)
def test_lora_fuse_nan(self):
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
def test_lora_fuse_nan(self, pipe):
_, text_lora_config, denoiser_lora_config = self.get_dummy_components()
_, _, inputs = self.get_dummy_inputs(with_generator=False)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
@@ -166,4 +162,4 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all())
assert np.isnan(out).all()
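`test_lora_fuse_nan` corrupts one LoRA weight with `inf` and checks that fusing with `safe_fusing=False` lets the non-finite values reach the output (with `safe_fusing=True` the fuse is meant to refuse such weights). A toy, pipeline-free illustration of why that happens; the real test asserts the final output is all-NaN, this sketch only shows the propagation step:

```py
import torch
import torch.nn as nn

torch.manual_seed(0)

base = nn.Linear(8, 8, bias=False)
lora_A = torch.randn(4, 8)
lora_B = torch.randn(8, 4)

# Corrupt one LoRA weight, as the test does on the transformer's state dict.
lora_A[0, 0] = float("inf")

# "Fuse" the update into the base weight without any finiteness check,
# the analogue of fuse_lora(..., safe_fusing=False).
with torch.no_grad():
    base.weight += lora_B @ lora_A

out = base(torch.randn(2, 8))
# The corrupted update propagates: the output contains non-finite values.
assert not torch.isfinite(out).all()
```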


@@ -13,8 +13,8 @@
# limitations under the License.
import sys
import unittest
import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestMochiLoRA(PeftLoraLoaderMixinTests):
pipeline_class = MochiPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -99,44 +99,44 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
@unittest.skip("Not supported in Mochi.")
@pytest.mark.skip("Not supported in Mochi.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Mochi.")
@pytest.mark.skip("Not supported in Mochi.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Mochi.")
@pytest.mark.skip("Not supported in Mochi.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@unittest.skip("Not supported in CogVideoX.")
@pytest.mark.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass


@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import pytest
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestQwenImageLoRA(PeftLoraLoaderMixinTests):
pipeline_class = QwenImagePipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -96,34 +96,34 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in Qwen Image.")
@pytest.mark.skip("Not supported in Qwen Image.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Qwen Image.")
@pytest.mark.skip("Not supported in Qwen Image.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Qwen Image.")
@pytest.mark.skip("Not supported in Qwen Image.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_save_load(self):
pass


@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import pytest
import torch
from transformers import Gemma2Model, GemmaTokenizer
@@ -29,7 +29,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestSanaLoRA(PeftLoraLoaderMixinTests):
pipeline_class = SanaPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {"shift": 7.0}
@@ -105,38 +105,38 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in SANA.")
@pytest.mark.skip("Not supported in SANA.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Not supported in SANA.")
@pytest.mark.skip("Not supported in SANA.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in SANA.")
@pytest.mark.skip("Not supported in SANA.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
@pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_inference_denoiser(self):
return super().test_layerwise_casting_inference_denoiser()


@@ -14,9 +14,9 @@
# limitations under the License.
import gc
import sys
import unittest
import numpy as np
import pytest
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
@@ -55,7 +55,7 @@ if is_accelerate_available():
from accelerate.utils import release_memory
class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
class TestStableDiffusionLoRA(PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusionPipeline
scheduler_cls = DDIMScheduler
scheduler_kwargs = {
@@ -91,16 +91,6 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
def output_shape(self):
return (1, 64, 64, 3)
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
# Keeping this test here makes sense because it doesn't look any integration
# (value assertions on logits).
@slow
@@ -114,15 +104,8 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipe.load_lora_weights(lora_id, adapter_name="adapter-2")
pipe = pipe.to(torch_device)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in unet",
)
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"
# We will offload the first adapter in CPU and check if the offloading
# has been performed correctly
@@ -130,35 +113,35 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
for name, module in pipe.unet.named_modules():
if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
assert module.weight.device == torch.device("cpu")
elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
assert module.weight.device != torch.device("cpu")
for name, module in pipe.text_encoder.named_modules():
if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
assert module.weight.device == torch.device("cpu")
elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
assert module.weight.device != torch.device("cpu")
pipe.set_lora_device(["adapter-1"], 0)
for n, m in pipe.unet.named_modules():
if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
assert m.weight.device != torch.device("cpu")
for n, m in pipe.text_encoder.named_modules():
if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
assert m.weight.device != torch.device("cpu")
pipe.set_lora_device(["adapter-1", "adapter-2"], torch_device)
for n, m in pipe.unet.named_modules():
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
assert m.weight.device != torch.device("cpu")
for n, m in pipe.text_encoder.named_modules():
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
assert m.weight.device != torch.device("cpu")
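The `set_lora_device` test walks `named_modules()` and asserts where each adapter's weights live, skipping `Dropout`/`Identity` which carry no parameters. A compact, self-contained sketch of that device-checking loop on a toy module tree (the helper and module names are hypothetical, for illustration only):

```py
import torch
import torch.nn as nn


def adapter_module_devices(model: nn.Module, adapter_name: str) -> set:
    # Devices of every weight-carrying submodule that belongs to one adapter.
    devices = set()
    for name, module in model.named_modules():
        if adapter_name in name and not isinstance(module, (nn.Dropout, nn.Identity)):
            if getattr(module, "weight", None) is not None:
                devices.add(module.weight.device)
    return devices


# Toy stand-in for a UNet holding per-adapter LoRA submodules.
model = nn.ModuleDict(
    {
        "to_q_lora_A_adapter-1": nn.Linear(4, 4),
        "to_q_lora_A_adapter-2": nn.Linear(4, 4),
    }
)

# After something like pipe.set_lora_device(["adapter-1"], "cpu") one would expect:
assert adapter_module_devices(model, "adapter-1") == {torch.device("cpu")}
```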
@slow
@require_torch_accelerator
@@ -181,15 +164,9 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in unet",
)
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"
for name, param in pipe.unet.named_parameters():
if "lora_" in name:
@@ -225,17 +202,14 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipe.unet.add_adapter(config1, adapter_name="adapter-1")
pipe = pipe.to(torch_device)
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in unet",
)
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"
# sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
self.assertNotEqual(modules_adapter_0, modules_adapter_1)
self.assertTrue(modules_adapter_0 - modules_adapter_1)
self.assertTrue(modules_adapter_1 - modules_adapter_0)
assert modules_adapter_0 != modules_adapter_1
assert modules_adapter_0 - modules_adapter_1
assert modules_adapter_1 - modules_adapter_0
# setting both separately works
pipe.set_lora_device(["adapter-0"], "cpu")
@@ -243,32 +217,30 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
for name, module in pipe.unet.named_modules():
if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
assert module.weight.device == torch.device("cpu")
elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
assert module.weight.device == torch.device("cpu")
# setting both at once also works
pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
for name, module in pipe.unet.named_modules():
if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
assert module.weight.device != torch.device("cpu")
elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
assert module.weight.device != torch.device("cpu")
@slow
@nightly
@require_torch_accelerator
@require_peft_backend
class LoraIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
class TestSDLoraIntegration:
@pytest.fixture(autouse=True)
def _gc_and_cache_cleanup(self, torch_device):
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
yield
gc.collect()
backend_empty_cache(torch_device)
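For integration classes that only need garbage collection and cache clearing around each test, the `setUp`/`tearDown` pair collapses into a single autouse yield fixture, as above. A minimal standalone version of that idiom; in the PR the fixture also requests a `torch_device` fixture and calls `backend_empty_cache`, while this sketch uses a plain CUDA call to stay self-contained:

```py
import gc

import pytest
import torch


class TestSomeLoraIntegration:
    @pytest.fixture(autouse=True)
    def _gc_and_cache_cleanup(self):
        # Runs automatically around every test in the class:
        # the code before yield replaces setUp, the code after it replaces tearDown.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        yield
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def test_something(self):
        assert torch.ones(2).sum().item() == 2.0
```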
@@ -280,10 +252,7 @@ class LoraIntegrationTests(unittest.TestCase):
pipe.load_lora_weights(lora_id)
pipe = pipe.to(torch_device)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
prompt = "a red sks dog"
@@ -312,10 +281,7 @@ class LoraIntegrationTests(unittest.TestCase):
pipe.load_lora_weights(lora_id)
pipe = pipe.to(torch_device)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
prompt = "a red sks dog"
@@ -587,8 +553,8 @@ class LoraIntegrationTests(unittest.TestCase):
).images
unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()
self.assertFalse(np.allclose(initial_images, lora_images))
self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))
assert not np.allclose(initial_images, lora_images)
assert np.allclose(initial_images, unloaded_lora_images, atol=1e-3)
release_memory(pipe)
@@ -625,8 +591,8 @@ class LoraIntegrationTests(unittest.TestCase):
).images
unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()
self.assertFalse(np.allclose(initial_images, lora_images))
self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))
assert not np.allclose(initial_images, lora_images)
assert np.allclose(initial_images, unloaded_lora_images, atol=1e-3)
# make sure we can load a LoRA again after unloading and they don't have
# any undesired effects.
@@ -637,7 +603,7 @@ class LoraIntegrationTests(unittest.TestCase):
).images
lora_images_again = lora_images_again[0, -3:, -3:, -1].flatten()
self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3))
assert np.allclose(lora_images, lora_images_again, atol=1e-3)
release_memory(pipe)
def test_not_empty_state_dict(self):
@@ -651,7 +617,7 @@ class LoraIntegrationTests(unittest.TestCase):
lcm_lora = load_file(cached_file)
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertTrue(lcm_lora != {})
assert lcm_lora != {}
release_memory(pipe)
def test_load_unload_load_state_dict(self):
@@ -666,11 +632,11 @@ class LoraIntegrationTests(unittest.TestCase):
previous_state_dict = lcm_lora.copy()
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertDictEqual(lcm_lora, previous_state_dict)
assert lcm_lora == previous_state_dict
pipe.unload_lora_weights()
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertDictEqual(lcm_lora, previous_state_dict)
assert lcm_lora == previous_state_dict
release_memory(pipe)


@@ -14,9 +14,9 @@
# limitations under the License.
import gc
import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -51,7 +51,7 @@ if is_accelerate_available():
@require_peft_backend
class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestSD3LoRA(PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -113,19 +113,19 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
lora_filename = "lora_peft_format.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
@unittest.skip("Not supported in SD3.")
@pytest.mark.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in SD3.")
@pytest.mark.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
@unittest.skip("Not supported in SD3.")
@pytest.mark.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in SD3.")
@pytest.mark.skip("Not supported in SD3.")
def test_modify_padding_mode(self):
pass
@@ -138,17 +138,15 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
class SD3LoraIntegrationTests(unittest.TestCase):
class TestSD3LoraIntegration:
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
def setUp(self):
super().setUp()
@pytest.fixture(autouse=True)
def _gc_and_cache_cleanup(self, torch_device):
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
yield
gc.collect()
backend_empty_cache(torch_device)


@@ -17,9 +17,9 @@ import gc
import importlib
import sys
import time
import unittest
import numpy as np
import pytest
import torch
from packaging import version
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -59,7 +59,7 @@ if is_accelerate_available():
from accelerate.utils import release_memory
class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
class TestStableDiffusionXLLoRA(PeftLoraLoaderMixinTests):
has_two_text_encoders = True
pipeline_class = StableDiffusionXLPipeline
scheduler_cls = EulerDiscreteScheduler
@@ -104,21 +104,11 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
def output_shape(self):
return (1, 64, 64, 3)
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
@is_flaky
def test_multiple_wrong_adapter_name_raises_error(self):
super().test_multiple_wrong_adapter_name_raises_error()
def test_simple_inference_with_text_denoiser_lora_unfused(self):
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
if torch.cuda.is_available():
expected_atol = 9e-2
expected_rtol = 9e-2
@@ -127,10 +117,10 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
expected_rtol = 1e-3
super().test_simple_inference_with_text_denoiser_lora_unfused(
expected_atol=expected_atol, expected_rtol=expected_rtol
pipe=pipe, expected_atol=expected_atol, expected_rtol=expected_rtol
)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
if torch.cuda.is_available():
expected_atol = 9e-2
expected_rtol = 9e-2
@@ -139,10 +129,10 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
expected_rtol = 1e-3
super().test_simple_inference_with_text_lora_denoiser_fused_multi(
expected_atol=expected_atol, expected_rtol=expected_rtol
pipe=pipe, expected_atol=expected_atol, expected_rtol=expected_rtol
)
def test_lora_scale_kwargs_match_fusion(self):
def test_lora_scale_kwargs_match_fusion(self, base_pipe_output):
if torch.cuda.is_available():
expected_atol = 9e-2
expected_rtol = 9e-2
@@ -150,21 +140,21 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
expected_atol = 1e-3
expected_rtol = 1e-3
super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol)
super().test_lora_scale_kwargs_match_fusion(
base_pipe_output=base_pipe_output, expected_atol=expected_atol, expected_rtol=expected_rtol
)
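Here the no-LoRA baseline arrives as the `base_pipe_output` fixture instead of a `get_base_pipe_output()` helper call. The actual fixtures live in the shared `utils.py`, whose diff is suppressed further down; purely as a sketch of the shape such fixtures might take (names mirror the diff, the implementation is hypothetical):

```py
import numpy as np
import pytest


@pytest.fixture
def pipe():
    # Stand-in for the pipeline built from get_dummy_components() in the real mixin.
    class DummyPipe:
        def __call__(self, **kwargs):
            return np.zeros((1, 64, 64, 3), dtype=np.float32)

    return DummyPipe()


@pytest.fixture
def base_pipe_output(pipe):
    # The no-LoRA baseline, computed once and reused by the assertions in a test.
    return pipe(num_inference_steps=2, output_type="np")


def test_lora_changes_output(pipe, base_pipe_output):
    lora_output = base_pipe_output + 0.1  # pretend a LoRA shifted the output
    assert not np.allclose(base_pipe_output, lora_output, atol=1e-3, rtol=1e-3)
```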
@slow
@nightly
@require_torch_accelerator
@require_peft_backend
class LoraSDXLIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
class TestLoraSDXLIntegration:
@pytest.fixture(autouse=True)
def _gc_and_cache_cleanup(self, torch_device):
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
yield
gc.collect()
backend_empty_cache(torch_device)
@@ -383,7 +373,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
end_time = time.time()
elapsed_time_fusion = end_time - start_time
self.assertTrue(elapsed_time_fusion < elapsed_time_non_fusion)
assert elapsed_time_fusion < elapsed_time_non_fusion
release_memory(pipe)
@@ -439,14 +429,14 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
for key, value in text_encoder_1_sd.items():
key = remap_key(key, fused_te_state_dict)
self.assertTrue(torch.allclose(fused_te_state_dict[key], value))
assert torch.allclose(fused_te_state_dict[key], value)
for key, value in text_encoder_2_sd.items():
key = remap_key(key, fused_te_2_state_dict)
self.assertTrue(torch.allclose(fused_te_2_state_dict[key], value))
assert torch.allclose(fused_te_2_state_dict[key], value)
for key, value in unet_state_dict.items():
self.assertTrue(torch.allclose(unet_state_dict[key], value))
assert torch.allclose(unet_state_dict[key], value)
pipe.fuse_lora()
pipe.unload_lora_weights()
@@ -589,7 +579,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
pipe.load_lora_weights(lora_id, weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipe = pipe.to(torch_device)
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet"
prompt = "toy_face of a hacker with a hoodie"


@@ -13,8 +13,8 @@
# limitations under the License.
import sys
import unittest
import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -39,7 +39,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestWanLoRA(PeftLoraLoaderMixinTests):
pipeline_class = WanPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -104,40 +104,40 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
@unittest.skip("Not supported in Wan.")
@pytest.mark.skip("Not supported in Wan.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Wan.")
@pytest.mark.skip("Not supported in Wan.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Wan.")
@pytest.mark.skip("Not supported in Wan.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_save_load(self):
pass


@@ -14,10 +14,9 @@
import os
import sys
import tempfile
import unittest
import numpy as np
import pytest
import safetensors.torch
import torch
from PIL import Image
@@ -32,7 +31,6 @@ from ..testing_utils import (
require_peft_backend,
require_peft_version_greater,
skip_mps,
torch_device,
)
@@ -47,7 +45,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
@is_flaky(max_attempts=10, description="very flaky class")
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestWanVACELoRA(PeftLoraLoaderMixinTests):
pipeline_class = WanVACEPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -121,56 +119,51 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
@unittest.skip("Not supported in Wan VACE.")
@pytest.mark.skip("Not supported in Wan VACE.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Wan VACE.")
@pytest.mark.skip("Not supported in Wan VACE.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Wan VACE.")
@pytest.mark.skip("Not supported in Wan VACE.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_save_load(self):
pass
def test_layerwise_casting_inference_denoiser(self):
super().test_layerwise_casting_inference_denoiser()
@require_peft_version_greater("0.13.2")
def test_lora_exclude_modules_wanvace(self):
def test_lora_exclude_modules_wanvace(self, base_pipe_output, tmpdirname, pipe):
exclude_module_name = "vace_blocks.0.proj_out"
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
_, text_lora_config, denoiser_lora_config = self.get_dummy_components()
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = self.get_base_pipe_output()
self.assertTrue(output_no_lora.shape == self.output_shape)
assert base_pipe_output.shape == self.output_shape
# only supported for `denoiser` now
denoiser_lora_config.target_modules = ["proj_out"]
@@ -180,36 +173,30 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         )
         # The state dict shouldn't contain the modules to be excluded from LoRA.
         state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default")
-        self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
-        self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
+        assert not any(exclude_module_name in k for k in state_dict_from_model)
+        assert any("proj_out" in k for k in state_dict_from_model)

         output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        with tempfile.TemporaryDirectory() as tmpdir:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts)
-            pipe.unload_lora_weights()
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+        self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
+        pipe.unload_lora_weights()

-            # Check in the loaded state dict.
-            loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
-            self.assertTrue(not any(exclude_module_name in k for k in loaded_state_dict))
-            self.assertTrue(any("proj_out" in k for k in loaded_state_dict))
+        # Check in the loaded state dict.
+        loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+        assert not any(exclude_module_name in k for k in loaded_state_dict)
+        assert any("proj_out" in k for k in loaded_state_dict)

-            # Check in the state dict obtained after loading LoRA.
-            pipe.load_lora_weights(tmpdir)
-            state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
-            self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
-            self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
+        # Check in the state dict obtained after loading LoRA.
+        pipe.load_lora_weights(tmpdirname)
+        state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
+        assert not any(exclude_module_name in k for k in state_dict_from_model)
+        assert any("proj_out" in k for k in state_dict_from_model)

-            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-            self.assertTrue(
-                not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
-                "LoRA should change outputs.",
-            )
-            self.assertTrue(
-                np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
-                "Lora outputs should match.",
-            )
-    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
-        super().test_simple_inference_with_text_denoiser_lora_and_scale()
+        output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        assert not np.allclose(base_pipe_output, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), (
+            "LoRA should change outputs."
+        )
+        assert np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), (
+            "Lora outputs should match."
+        )

File diff suppressed because it is too large
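For orientation: the suppressed diff is presumably the shared `PeftLoraLoaderMixinTests` helper that the migrated tests above lean on, where the `pipe`, `tmpdirname`, and `base_pipe_output` arguments in the new test signatures would be defined as pytest fixtures. The snippet below is only a minimal sketch of how such class-level fixtures could look, inferred from the visible call sites (`pipeline_class`, `get_dummy_components`, `get_dummy_inputs`); the fixture bodies and scoping are assumptions, not the actual contents of the suppressed file.

```py
import pytest
import torch

torch_device = "cuda" if torch.cuda.is_available() else "cpu"


class PeftLoraLoaderMixinTests:
    # Hypothetical sketches of the fixtures referenced by the migrated tests.
    # Concrete subclasses such as TestWanVACELoRA supply pipeline_class,
    # get_dummy_components, and get_dummy_inputs, as shown in the diff above.

    @pytest.fixture
    def pipe(self):
        # Build the pipeline under test from the dummy components.
        components, _, _ = self.get_dummy_components()
        return self.pipeline_class(**components).to(torch_device)

    @pytest.fixture
    def tmpdirname(self, tmp_path):
        # Thin wrapper over pytest's built-in tmp_path, replacing the old
        # `with tempfile.TemporaryDirectory()` blocks in each test body.
        return str(tmp_path)

    @pytest.fixture
    def base_pipe_output(self, pipe):
        # Pipeline output before any LoRA is loaded; tests compare against it.
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        return pipe(**inputs, generator=torch.manual_seed(0))[0]
```

Defining the setup as fixtures on the mixin keeps the per-model subclasses free of tempfile and pipeline-construction boilerplate, with pytest handling creation and cleanup.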

View File

@@ -1,93 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import PIL
import pytest
from diffusers.modular_pipelines import (
Flux2AutoBlocks,
Flux2ModularPipeline,
)
from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
    pipeline_class = Flux2ModularPipeline
    pipeline_blocks_class = Flux2AutoBlocks
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
    params = frozenset(["prompt", "height", "width", "guidance_scale"])
    batch_params = frozenset(["prompt"])

    def get_dummy_inputs(self, seed=0):
        generator = self.get_generator(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
            "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
            "text_encoder_out_layers": (1,),
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 4.0,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }
        return inputs

    def test_float16_inference(self):
        super().test_float16_inference(9e-2)


class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
    pipeline_class = Flux2ModularPipeline
    pipeline_blocks_class = Flux2AutoBlocks
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
    params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
    batch_params = frozenset(["prompt", "image"])

    def get_dummy_inputs(self, seed=0):
        generator = self.get_generator(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
            "max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
            "text_encoder_out_layers": (1,),
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 4.0,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }
        image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
        image = image.cpu().permute(0, 2, 3, 1)[0]
        init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
        inputs["image"] = init_image
        return inputs

    def test_float16_inference(self):
        super().test_float16_inference(9e-2)

    @pytest.mark.skip(reason="batched inference is currently not supported")
    def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
        return

View File

@@ -165,6 +165,7 @@ class ModularPipelineTesterMixin:
        expected_max_diff=1e-4,
    ):
        pipe = self.get_pipeline().to(torch_device)
        inputs = self.get_dummy_inputs()
        # Reset generator in case it has been used in self.get_dummy_inputs