Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-12 07:24:32 +08:00)

Compare commits: remove-lor ... migrate-lo (31 commits)
Commit SHAs in this comparison:

db250958c5, f956ba0db1, f3593a8aa9, 1b6cdea043, 3fb66f23ac, 9c3bed1783, 11b80d09b0, 9201505554, eece7120dd, 2e42205c3a, 757bbf7b05, 4561c065aa, 4ae5772fef, 0d3da485a0, 4f5e9a665e, 23e5559c54, f8f27891c6, 128535cfcd, bdc9537999, dae161ed26, c4bcf72084, 1737b710a2, 565d674cc4, 610842af1a, cba82591e8, 949cc1c326, ec866f5de8, 7b4bcce602, d61bb38fb4, 9e92f6bb63, 6c6cade1a7
@@ -237,8 +237,6 @@ By selectively loading and unloading the models you need at a given stage and sh

Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends.

Most attention backends are compatible with context parallelism. Open an [issue](https://github.com/huggingface/diffusers/issues/new) if a backend is not compatible.
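Switching backends is a one-line call on the transformer. A minimal sketch, assuming `pipeline` is an already loaded `DiffusionPipeline`; `"_native_cudnn"` and `"flash"` are the backend names used in the examples below:

```py
# Assumes `pipeline` was created with DiffusionPipeline.from_pretrained(...)
pipeline.transformer.set_attention_backend("_native_cudnn")  # or "flash"
```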
### Ring Attention

Key (K) and value (V) representations communicate between devices using [Ring Attention](https://huggingface.co/papers/2310.01889). This ensures each sequence split sees every other token's K/V. Each GPU computes attention for its local queries against the K/V block it currently holds, then passes that K/V block to the next GPU in the ring, overlapping communication with computation. No single GPU ever holds the full sequence, which keeps memory low and hides communication latency.
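To make the ring pattern concrete, here is a small single-process simulation (hypothetical shapes, plain PyTorch, not the diffusers implementation) showing that rotating K/V blocks around the ring and combining the partial results reproduces full attention; real implementations use an online softmax instead of materializing all scores:

```py
import torch

n_dev, seq, dim = 4, 16, 8
q = torch.randn(n_dev, seq // n_dev, dim)   # local query shards, one per "GPU"
k = torch.randn(n_dev, seq // n_dev, dim)   # local key shards
v = torch.randn(n_dev, seq // n_dev, dim)   # local value shards

out = []
for rank in range(n_dev):
    scores, values = [], []
    k_blk, v_blk = k[rank], v[rank]
    for hop in range(n_dev):
        # attend local queries against whichever K/V block this rank currently holds
        scores.append(q[rank] @ k_blk.T / dim**0.5)
        values.append(v_blk)
        # "pass" the K/V block to the next rank in the ring
        nxt = (rank + hop + 1) % n_dev
        k_blk, v_blk = k[nxt], v[nxt]
    attn = torch.softmax(torch.cat(scores, dim=-1), dim=-1)
    out.append(attn @ torch.cat(values, dim=0))
out = torch.cat(out, dim=0)

# Reference: full attention over the unsharded sequence
q_full, k_full, v_full = (t.reshape(seq, dim) for t in (q, k, v))
ref = torch.softmax(q_full @ k_full.T / dim**0.5, dim=-1) @ v_full
assert torch.allclose(out, ref, atol=1e-4)
```
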
@@ -247,60 +245,40 @@ Pass a [`ContextParallelConfig`] to the `parallel_config` argument of the transf

Removed (Flux) example:

```py
import torch
from torch import distributed as dist
from diffusers import DiffusionPipeline, ContextParallelConfig

def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device

def main():
    device = setup_distributed()
    world_size = dist.get_world_size()

    pipeline = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
    )
    pipeline.transformer.set_attention_backend("_native_cudnn")

    cp_config = ContextParallelConfig(ring_degree=world_size)
    pipeline.transformer.enable_parallelism(config=cp_config)

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    # Must specify generator so all ranks start with same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
    image = pipeline(
        prompt,
        guidance_scale=3.5,
        num_inference_steps=50,
        generator=generator,
    ).images[0]

    if dist.get_rank() == 0:
        image.save(f"output.png")

    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Added (Qwen-Image) example:

```py
import torch
from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig

try:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())
    torch.cuda.set_device(device)

    transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2))
    pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda")
    pipeline.transformer.set_attention_backend("flash")

    prompt = """
    cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
    """

    # Must specify generator so all ranks start with same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
    image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0]

    if rank == 0:
        image.save("output.png")
except Exception as e:
    print(f"An error occurred: {e}")
    torch.distributed.breakpoint()
    raise
finally:
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
```

The script above needs to be run with a distributed launcher compatible with PyTorch, such as [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html). Set `--nproc-per-node` to the number of GPUs available.

```shell
torchrun --nproc-per-node 2 above_script.py
```

### Ulysses Attention

[Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* communication (every device sends/receives data to every other device). Each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally over all tokens for its heads, then performs another all-to-all to regroup results by tokens for the next layer.
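To make the head/sequence exchange concrete, here is a small single-process simulation (hypothetical shapes, plain PyTorch, not the diffusers implementation):

```py
import torch

n_dev, seq, heads, dim = 2, 8, 4, 16
# (device, local_seq, heads, dim): each device owns a slice of the sequence, all heads
q = torch.randn(n_dev, seq // n_dev, heads, dim)
k = torch.randn(n_dev, seq // n_dev, heads, dim)
v = torch.randn(n_dev, seq // n_dev, heads, dim)

def all_to_all(x):
    # (device, local_seq, heads, dim) -> (device, seq, local_heads, dim):
    # trade sequence shards for head shards, so each device sees every token
    return (
        x.reshape(n_dev, x.shape[1], n_dev, heads // n_dev, dim)
        .permute(2, 0, 1, 3, 4)
        .reshape(n_dev, seq, heads // n_dev, dim)
    )

q2, k2, v2 = all_to_all(q), all_to_all(k), all_to_all(v)

# Local attention: every device now holds all tokens for its subset of heads
scores = torch.einsum("dqhc,dkhc->dhqk", q2, k2) / dim**0.5
out = torch.einsum("dhqk,dkhc->dqhc", scores.softmax(dim=-1), v2)
# A second all-to-all would regroup `out` back into sequence shards for the next layer.
```
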
@@ -310,26 +288,5 @@ The script above needs to be run with a distributed launcher, such as [torchrun]

Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].

```py
# Set the degree depending on the number of GPUs available.
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
```

### parallel_config

Pass `parallel_config` during model initialization to enable context parallelism.

```py
CKPT_ID = "black-forest-labs/FLUX.1-dev"

cp_config = ContextParallelConfig(ring_degree=2)
transformer = AutoModel.from_pretrained(
    CKPT_ID,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=cp_config,
)

pipeline = DiffusionPipeline.from_pretrained(
    CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
).to(device)
```

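As in the launcher example above, the parallel degree should line up with the number of processes started by `torchrun`. A small sketch, assuming the process group has already been initialized as shown earlier:

```py
import torch.distributed as dist

# Tie the ring degree to the launched world size (as the earlier example does)
world_size = dist.get_world_size()
cp_config = ContextParallelConfig(ring_degree=world_size)
```
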
@@ -404,8 +404,6 @@ except OptionalDependencyNotAvailable:
else:
    _import_structure["modular_pipelines"].extend(
        [
            "Flux2AutoBlocks",
            "Flux2ModularPipeline",
            "FluxAutoBlocks",
            "FluxKontextAutoBlocks",
            "FluxKontextModularPipeline",
@@ -1113,8 +1111,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
else:
    from .modular_pipelines import (
        Flux2AutoBlocks,
        Flux2ModularPipeline,
        FluxAutoBlocks,
        FluxKontextAutoBlocks,
        FluxKontextModularPipeline,
@@ -1487,7 +1487,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
    Load LoRA layers into [`FluxTransformer2DModel`],
    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).

    Specific to [`FluxPipeline`].
    Specific to [`StableDiffusion3Pipeline`].
    """

    _lora_loadable_modules = ["transformer", "text_encoder"]
@@ -1628,7 +1628,30 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
|
||||
`self.text_encoder`.
|
||||
|
||||
All kwargs are forwarded to `self.lora_state_dict`.
|
||||
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
|
||||
loaded.
|
||||
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
|
||||
dict is loaded into `self.transformer`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
@@ -3628,17 +3651,44 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
|
||||
Return state dict for lora weights and the network alphas.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
- A string, the *model id* of a pretrained model hosted on the Hub.
|
||||
- A path to a *directory* containing the model weights.
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository.
|
||||
weight_name (`str`, *optional*, defaults to None):
|
||||
Name of the serialized state dict file.
|
||||
use_safetensors (`bool`, *optional*):
|
||||
Whether to use safetensors for loading.
|
||||
return_lora_metadata (`bool`, *optional*, defaults to False):
|
||||
When enabled, additionally return the LoRA adapter metadata.
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
# Load the main state dict first which has the LoRA layers
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
@@ -3681,7 +3731,6 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
return out
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
@@ -3690,13 +3739,26 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model.
|
||||
hotswap (`bool`, *optional*):
|
||||
Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
|
||||
weights.
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
@@ -3713,6 +3775,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
# Load LoRA into transformer
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
@@ -3724,7 +3787,6 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
|
||||
def load_lora_into_transformer(
|
||||
cls,
|
||||
state_dict,
|
||||
@@ -3736,9 +3798,23 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
metadata=None,
|
||||
):
|
||||
"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
|
||||
Load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters.
|
||||
transformer (`Kandinsky5Transformer3DModel`):
|
||||
The transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model.
|
||||
low_cpu_mem_usage (`bool`, *optional*):
|
||||
Speed up model loading by only loading the pretrained LoRA weights.
|
||||
hotswap (`bool`, *optional*):
|
||||
See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
|
||||
metadata (`dict`):
|
||||
Optional LoRA adapter metadata.
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
@@ -3756,7 +3832,6 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
@@ -3765,10 +3840,24 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
transformer_lora_adapter_metadata: Optional[dict] = None,
|
||||
transformer_lora_adapter_metadata=None,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
|
||||
Save the LoRA parameters corresponding to the transformer and text encoders.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to.
|
||||
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `transformer`.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way.
|
||||
transformer_lora_adapter_metadata:
|
||||
LoRA adapter metadata associated with the transformer.
|
||||
"""
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
@@ -3778,7 +3867,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers`")
|
||||
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
@@ -3790,7 +3879,6 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
@@ -3800,7 +3888,25 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing.
|
||||
|
||||
Example:
|
||||
```py
|
||||
from diffusers import Kandinsky5T2VPipeline
|
||||
|
||||
pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
|
||||
pipeline.load_lora_weights("path/to/lora.safetensors")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
@@ -3810,10 +3916,12 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
Reverses the effect of [`pipe.fuse_lora()`].
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
@@ -52,10 +52,6 @@ else:
|
||||
"FluxKontextAutoBlocks",
|
||||
"FluxKontextModularPipeline",
|
||||
]
|
||||
_import_structure["flux2"] = [
|
||||
"Flux2AutoBlocks",
|
||||
"Flux2ModularPipeline",
|
||||
]
|
||||
_import_structure["qwenimage"] = [
|
||||
"QwenImageAutoBlocks",
|
||||
"QwenImageModularPipeline",
|
||||
@@ -79,7 +75,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .components_manager import ComponentsManager
|
||||
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
|
||||
from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
|
||||
from .modular_pipeline import (
|
||||
AutoPipelineBlocks,
|
||||
BlockState,
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["encoders"] = [
|
||||
"Flux2TextEncoderStep",
|
||||
"Flux2RemoteTextEncoderStep",
|
||||
"Flux2VaeEncoderStep",
|
||||
]
|
||||
_import_structure["before_denoise"] = [
|
||||
"Flux2SetTimestepsStep",
|
||||
"Flux2PrepareLatentsStep",
|
||||
"Flux2RoPEInputsStep",
|
||||
"Flux2PrepareImageLatentsStep",
|
||||
]
|
||||
_import_structure["denoise"] = [
|
||||
"Flux2LoopDenoiser",
|
||||
"Flux2LoopAfterDenoiser",
|
||||
"Flux2DenoiseLoopWrapper",
|
||||
"Flux2DenoiseStep",
|
||||
]
|
||||
_import_structure["decoders"] = ["Flux2DecodeStep"]
|
||||
_import_structure["inputs"] = [
|
||||
"Flux2ProcessImagesInputStep",
|
||||
"Flux2TextInputStep",
|
||||
]
|
||||
_import_structure["modular_blocks"] = [
|
||||
"ALL_BLOCKS",
|
||||
"AUTO_BLOCKS",
|
||||
"REMOTE_AUTO_BLOCKS",
|
||||
"TEXT2IMAGE_BLOCKS",
|
||||
"IMAGE_CONDITIONED_BLOCKS",
|
||||
"Flux2AutoBlocks",
|
||||
"Flux2AutoVaeEncoderStep",
|
||||
"Flux2BeforeDenoiseStep",
|
||||
"Flux2VaeEncoderSequentialStep",
|
||||
]
|
||||
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .before_denoise import (
|
||||
Flux2PrepareImageLatentsStep,
|
||||
Flux2PrepareLatentsStep,
|
||||
Flux2RoPEInputsStep,
|
||||
Flux2SetTimestepsStep,
|
||||
)
|
||||
from .decoders import Flux2DecodeStep
|
||||
from .denoise import (
|
||||
Flux2DenoiseLoopWrapper,
|
||||
Flux2DenoiseStep,
|
||||
Flux2LoopAfterDenoiser,
|
||||
Flux2LoopDenoiser,
|
||||
)
|
||||
from .encoders import (
|
||||
Flux2RemoteTextEncoderStep,
|
||||
Flux2TextEncoderStep,
|
||||
Flux2VaeEncoderStep,
|
||||
)
|
||||
from .inputs import (
|
||||
Flux2ProcessImagesInputStep,
|
||||
Flux2TextInputStep,
|
||||
)
|
||||
from .modular_blocks import (
|
||||
ALL_BLOCKS,
|
||||
AUTO_BLOCKS,
|
||||
IMAGE_CONDITIONED_BLOCKS,
|
||||
REMOTE_AUTO_BLOCKS,
|
||||
TEXT2IMAGE_BLOCKS,
|
||||
Flux2AutoBlocks,
|
||||
Flux2AutoVaeEncoderStep,
|
||||
Flux2BeforeDenoiseStep,
|
||||
Flux2VaeEncoderSequentialStep,
|
||||
)
|
||||
from .modular_pipeline import Flux2ModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,508 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...models import Flux2Transformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import Flux2ModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
|
||||
"""Compute empirical mu for Flux2 timestep scheduling."""
|
||||
a1, b1 = 8.73809524e-05, 1.89833333
|
||||
a2, b2 = 0.00016927, 0.45666666
|
||||
|
||||
if image_seq_len > 4300:
|
||||
mu = a2 * image_seq_len + b2
|
||||
return float(mu)
|
||||
|
||||
m_200 = a2 * image_seq_len + b2
|
||||
m_10 = a1 * image_seq_len + b1
|
||||
|
||||
a = (m_200 - m_10) / 190.0
|
||||
b = m_200 - 200.0 * a
|
||||
mu = a * num_steps + b
|
||||
|
||||
return float(mu)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class Flux2SetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
ComponentSpec("transformer", Flux2Transformer2DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that sets the scheduler's timesteps for Flux2 inference using empirical mu calculation"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_inference_steps", default=50),
|
||||
InputParam("timesteps"),
|
||||
InputParam("sigmas"),
|
||||
InputParam("guidance_scale", default=4.0),
|
||||
InputParam("latents", type_hint=torch.Tensor),
|
||||
InputParam("num_images_per_prompt", default=1),
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
|
||||
OutputParam(
|
||||
"num_inference_steps",
|
||||
type_hint=int,
|
||||
description="The number of denoising steps to perform at inference time",
|
||||
),
|
||||
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.device = components._execution_device
|
||||
|
||||
scheduler = components.scheduler
|
||||
|
||||
height = block_state.height or components.default_height
|
||||
width = block_state.width or components.default_width
|
||||
vae_scale_factor = components.vae_scale_factor
|
||||
|
||||
latent_height = 2 * (int(height) // (vae_scale_factor * 2))
|
||||
latent_width = 2 * (int(width) // (vae_scale_factor * 2))
|
||||
image_seq_len = (latent_height // 2) * (latent_width // 2)
|
||||
|
||||
num_inference_steps = block_state.num_inference_steps
|
||||
sigmas = block_state.sigmas
|
||||
timesteps = block_state.timesteps
|
||||
|
||||
if timesteps is None and sigmas is None:
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
|
||||
sigmas = None
|
||||
|
||||
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
|
||||
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps,
|
||||
block_state.device,
|
||||
timesteps=timesteps,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
block_state.timesteps = timesteps
|
||||
block_state.num_inference_steps = num_inference_steps
|
||||
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
|
||||
guidance = guidance.expand(batch_size)
|
||||
block_state.guidance = guidance
|
||||
|
||||
components.scheduler.set_begin_index(0)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2PrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return []
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Prepare latents step that prepares the initial noise latents for Flux2 text-to-image generation"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("height", type_hint=int),
|
||||
InputParam("width", type_hint=int),
|
||||
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||
InputParam("num_images_per_prompt", type_hint=int, default=1),
|
||||
InputParam("generator"),
|
||||
InputParam(
|
||||
"batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
|
||||
),
|
||||
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
|
||||
),
|
||||
OutputParam("latent_ids", type_hint=torch.Tensor, description="Position IDs for the latents (for RoPE)"),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(components, block_state):
|
||||
vae_scale_factor = components.vae_scale_factor
|
||||
if (block_state.height is not None and block_state.height % (vae_scale_factor * 2) != 0) or (
|
||||
block_state.width is not None and block_state.width % (vae_scale_factor * 2) != 0
|
||||
):
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {vae_scale_factor * 2} but are {block_state.height} and {block_state.width}."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _prepare_latent_ids(latents: torch.Tensor):
|
||||
"""
|
||||
Generates 4D position coordinates (T, H, W, L) for latent tensors.
|
||||
|
||||
Args:
|
||||
latents: Latent tensor of shape (B, C, H, W)
|
||||
|
||||
Returns:
|
||||
Position IDs tensor of shape (B, H*W, 4)
|
||||
"""
|
||||
batch_size, _, height, width = latents.shape
|
||||
|
||||
t = torch.arange(1)
|
||||
h = torch.arange(height)
|
||||
w = torch.arange(width)
|
||||
l = torch.arange(1)
|
||||
|
||||
latent_ids = torch.cartesian_prod(t, h, w, l)
|
||||
latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
|
||||
return latent_ids
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents):
|
||||
"""Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)"""
|
||||
batch_size, num_channels, height, width = latents.shape
|
||||
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def prepare_latents(
|
||||
comp,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = 2 * (int(height) // (comp.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (comp.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_channels_latents * 4, height // 2, width // 2)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.height = block_state.height or components.default_height
|
||||
block_state.width = block_state.width or components.default_width
|
||||
block_state.device = components._execution_device
|
||||
block_state.num_channels_latents = components.num_channels_latents
|
||||
|
||||
self.check_inputs(components, block_state)
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
|
||||
latents = self.prepare_latents(
|
||||
components,
|
||||
batch_size,
|
||||
block_state.num_channels_latents,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
block_state.dtype,
|
||||
block_state.device,
|
||||
block_state.generator,
|
||||
block_state.latents,
|
||||
)
|
||||
|
||||
latent_ids = self._prepare_latent_ids(latents)
|
||||
latent_ids = latent_ids.to(block_state.device)
|
||||
|
||||
latents = self._pack_latents(latents)
|
||||
|
||||
block_state.latents = latents
|
||||
block_state.latent_ids = latent_ids
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2RoPEInputsStep(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that prepares the 4D RoPE position IDs for Flux2 denoising. Should be placed after text encoder and latent preparation steps."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="prompt_embeds", required=True),
|
||||
InputParam(name="latent_ids"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="txt_ids",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
|
||||
),
|
||||
OutputParam(
|
||||
name="latent_ids",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
|
||||
"""Prepare 4D position IDs for text tokens."""
|
||||
B, L, _ = x.shape
|
||||
out_ids = []
|
||||
|
||||
for i in range(B):
|
||||
t = torch.arange(1) if t_coord is None else t_coord[i]
|
||||
h = torch.arange(1)
|
||||
w = torch.arange(1)
|
||||
seq_l = torch.arange(L)
|
||||
|
||||
coords = torch.cartesian_prod(t, h, w, seq_l)
|
||||
out_ids.append(coords)
|
||||
|
||||
return torch.stack(out_ids)
|
||||
|
||||
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
prompt_embeds = block_state.prompt_embeds
|
||||
device = prompt_embeds.device
|
||||
|
||||
block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
|
||||
block_state.txt_ids = block_state.txt_ids.to(device)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that prepares image latents and their position IDs for Flux2 image conditioning."
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("image_latents", type_hint=List[torch.Tensor]),
|
||||
InputParam("batch_size", required=True, type_hint=int),
|
||||
InputParam("num_images_per_prompt", default=1, type_hint=int),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="Packed image latents for conditioning",
|
||||
),
|
||||
OutputParam(
|
||||
"image_latent_ids",
|
||||
type_hint=torch.Tensor,
|
||||
description="Position IDs for image latents",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _prepare_image_ids(image_latents: List[torch.Tensor], scale: int = 10):
|
||||
"""
|
||||
Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
|
||||
|
||||
Args:
|
||||
image_latents: A list of image latent feature tensors of shape (1, C, H, W).
|
||||
scale: Factor used to define the time separation between latents.
|
||||
|
||||
Returns:
|
||||
Combined coordinate tensor of shape (1, N_total, 4)
|
||||
"""
|
||||
if not isinstance(image_latents, list):
|
||||
raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
|
||||
|
||||
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
|
||||
t_coords = [t.view(-1) for t in t_coords]
|
||||
|
||||
image_latent_ids = []
|
||||
for x, t in zip(image_latents, t_coords):
|
||||
x = x.squeeze(0)
|
||||
_, height, width = x.shape
|
||||
|
||||
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
|
||||
image_latent_ids.append(x_ids)
|
||||
|
||||
image_latent_ids = torch.cat(image_latent_ids, dim=0)
|
||||
image_latent_ids = image_latent_ids.unsqueeze(0)
|
||||
|
||||
return image_latent_ids
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents):
|
||||
"""Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)"""
|
||||
batch_size, num_channels, height, width = latents.shape
|
||||
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
image_latents = block_state.image_latents
|
||||
|
||||
if image_latents is None:
|
||||
block_state.image_latents = None
|
||||
block_state.image_latent_ids = None
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
device = components._execution_device
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
|
||||
image_latent_ids = self._prepare_image_ids(image_latents)
|
||||
|
||||
packed_latents = []
|
||||
for latent in image_latents:
|
||||
packed = self._pack_latents(latent)
|
||||
packed = packed.squeeze(0)
|
||||
packed_latents.append(packed)
|
||||
|
||||
image_latents = torch.cat(packed_latents, dim=0)
|
||||
image_latents = image_latents.unsqueeze(0)
|
||||
|
||||
image_latents = image_latents.repeat(batch_size, 1, 1)
|
||||
image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
|
||||
image_latent_ids = image_latent_ids.to(device)
|
||||
|
||||
block_state.image_latents = image_latents
|
||||
block_state.image_latent_ids = image_latent_ids
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
@@ -1,146 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKLFlux2
|
||||
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class Flux2DecodeStep(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLFlux2),
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
Flux2ImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("output_type", default="pil"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The denoised latents from the denoising step",
|
||||
),
|
||||
InputParam(
|
||||
"latent_ids",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="Position IDs for the latents, used for unpacking",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"images",
|
||||
type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
|
||||
description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Unpack latents using position IDs to scatter tokens into place.
|
||||
|
||||
Args:
|
||||
x: Packed latents tensor of shape (B, seq_len, C)
|
||||
x_ids: Position IDs tensor of shape (B, seq_len, 4) with (T, H, W, L) coordinates
|
||||
|
||||
Returns:
|
||||
Unpacked latents tensor of shape (B, C, H, W)
|
||||
"""
|
||||
x_list = []
|
||||
for data, pos in zip(x, x_ids):
|
||||
_, ch = data.shape # noqa: F841
|
||||
h_ids = pos[:, 1].to(torch.int64)
|
||||
w_ids = pos[:, 2].to(torch.int64)
|
||||
|
||||
h = torch.max(h_ids) + 1
|
||||
w = torch.max(w_ids) + 1
|
||||
|
||||
flat_ids = h_ids * w + w_ids
|
||||
|
||||
out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
|
||||
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
|
||||
|
||||
out = out.view(h, w, ch).permute(2, 0, 1)
|
||||
x_list.append(out)
|
||||
|
||||
return torch.stack(x_list, dim=0)
|
||||
|
||||
@staticmethod
|
||||
def _unpatchify_latents(latents):
|
||||
"""Convert patchified latents back to regular format."""
|
||||
batch_size, num_channels_latents, height, width = latents.shape
|
||||
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
|
||||
latents = latents.permute(0, 1, 4, 2, 5, 3)
|
||||
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
vae = components.vae
|
||||
|
||||
if block_state.output_type == "latent":
|
||||
block_state.images = block_state.latents
|
||||
else:
|
||||
latents = block_state.latents
|
||||
latent_ids = block_state.latent_ids
|
||||
|
||||
latents = self._unpack_latents_with_ids(latents, latent_ids)
|
||||
|
||||
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents * latents_bn_std + latents_bn_mean
|
||||
|
||||
latents = self._unpatchify_latents(latents)
|
||||
|
||||
block_state.images = vae.decode(latents, return_dict=False)[0]
|
||||
block_state.images = components.image_processor.postprocess(
|
||||
block_state.images, output_type=block_state.output_type
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
@@ -1,252 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import Flux2Transformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import Flux2ModularPipeline
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class Flux2LoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [ComponentSpec("transformer", Flux2Transformer2DModel)]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step within the denoising loop that denoises the latents for Flux2. "
|
||||
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||
"object (e.g. `Flux2DenoiseLoopWrapper`)"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("joint_attention_kwargs"),
|
||||
InputParam(
|
||||
"latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents to denoise. Shape: (B, seq_len, C)",
|
||||
),
|
||||
InputParam(
|
||||
"image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
|
||||
),
|
||||
InputParam(
|
||||
"image_latent_ids",
|
||||
type_hint=torch.Tensor,
|
||||
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
|
||||
),
|
||||
InputParam(
|
||||
"guidance",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="Guidance scale as a tensor",
|
||||
),
|
||||
InputParam(
|
||||
"prompt_embeds",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="Text embeddings from Mistral3",
|
||||
),
|
||||
InputParam(
|
||||
"txt_ids",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs for text tokens (T, H, W, L)",
|
||||
),
|
||||
InputParam(
|
||||
"latent_ids",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="4D position IDs for latent tokens (T, H, W, L)",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
|
||||
) -> PipelineState:
|
||||
latents = block_state.latents
|
||||
latent_model_input = latents.to(components.transformer.dtype)
|
||||
img_ids = block_state.latent_ids
|
||||
|
||||
image_latents = getattr(block_state, "image_latents", None)
|
||||
if image_latents is not None:
|
||||
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
|
||||
image_latent_ids = block_state.image_latent_ids
|
||||
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
|
||||
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
noise_pred = components.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=block_state.guidance,
|
||||
encoder_hidden_states=block_state.prompt_embeds,
|
||||
txt_ids=block_state.txt_ids,
|
||||
img_ids=img_ids,
|
||||
joint_attention_kwargs=block_state.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
noise_pred = noise_pred[:, : latents.size(1)]
|
||||
block_state.noise_pred = noise_pred
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
class Flux2LoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
        return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that updates the latents after denoising. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `Flux2DenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return []

    @property
    def intermediate_inputs(self) -> List[str]:
        return [InputParam("generator")]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        latents_dtype = block_state.latents.dtype
        block_state.latents = components.scheduler.step(
            block_state.noise_pred,
            t,
            block_state.latents,
            return_dict=False,
        )[0]

        if block_state.latents.dtype != latents_dtype:
            if torch.backends.mps.is_available():
                block_state.latents = block_state.latents.to(latents_dtype)

        return components, block_state


class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return (
            "Pipeline block that iteratively denoises the latents over `timesteps`. "
            "The specific steps within each iteration can be customized with `sub_blocks` attribute"
        )

    @property
    def loop_expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
            ComponentSpec("transformer", Flux2Transformer2DModel),
        ]

    @property
    def loop_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process.",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.num_warmup_steps = max(
            len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
        )

        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
            for i, t in enumerate(block_state.timesteps):
                components, block_state = self.loop_step(components, block_state, i=i, t=t)

                if i == len(block_state.timesteps) - 1 or (
                    (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
                ):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self.set_block_state(state, block_state)
        return components, state


class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
    block_classes = [Flux2LoopDenoiser, Flux2LoopAfterDenoiser]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents for Flux2. \n"
            "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `Flux2LoopDenoiser`\n"
            " - `Flux2LoopAfterDenoiser`\n"
            "This block supports both text-to-image and image-conditioned generation."
        )
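Not part of the diff above: a minimal sketch of how the loop wrapper and its sub-blocks compose, using only the classes shown in this hunk. The instantiation pattern is assumed to match other diffusers modular blocks.

```py
# Hypothetical sketch: inspect how the denoise loop is composed.
denoise_step = Flux2DenoiseStep()

# The wrapper's __call__ drives the loop; each iteration runs the sub-blocks in order.
print(denoise_step.block_names)   # ["denoiser", "after_denoiser"]
print(denoise_step.description)   # documents the per-iteration composition
```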
@@ -1,349 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union

import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration

from ...models import AutoencoderKLFlux2
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def format_text_input(prompts: List[str], system_message: str = None):
    """Format prompts for Mistral3 chat template."""
    cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]

    return [
        [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {"role": "user", "content": [{"type": "text", "text": prompt}]},
        ]
        for prompt in cleaned_txt
    ]


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


class Flux2TextEncoderStep(ModularPipelineBlocks):
    model_name = "flux2"

    # fmt: off
    DEFAULT_SYSTEM_MESSAGE = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
    # fmt: on

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text embeddings using Mistral3 to guide the image generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", Mistral3ForConditionalGeneration),
            ComponentSpec("tokenizer", AutoProcessor),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
            InputParam("max_sequence_length", type_hint=int, default=512, required=False),
            InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
            InputParam("joint_attention_kwargs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Text embeddings from Mistral3 used to guide the image generation",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        prompt = block_state.prompt
        prompt_embeds = getattr(block_state, "prompt_embeds", None)

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
                "Please make sure to only forward one of the two."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    @staticmethod
    def _get_mistral_3_prompt_embeds(
        text_encoder: Mistral3ForConditionalGeneration,
        tokenizer: AutoProcessor,
        prompt: Union[str, List[str]],
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        max_sequence_length: int = 512,
        # fmt: off
        system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
        # fmt: on
        hidden_states_layers: Tuple[int] = (10, 20, 30),
    ):
        dtype = text_encoder.dtype if dtype is None else dtype
        device = text_encoder.device if device is None else device

        prompt = [prompt] if isinstance(prompt, str) else prompt

        messages_batch = format_text_input(prompts=prompt, system_message=system_message)

        inputs = tokenizer.apply_chat_template(
            messages_batch,
            add_generation_prompt=False,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_sequence_length,
        )

        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        output = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

        out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
        out = out.to(dtype=dtype, device=device)

        batch_size, num_channels, seq_len, hidden_dim = out.shape
        prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)

        return prompt_embeds

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.device = components._execution_device

        if block_state.prompt_embeds is not None:
            self.set_block_state(state, block_state)
            return components, state

        prompt = block_state.prompt
        if prompt is None:
            prompt = ""
        prompt = [prompt] if isinstance(prompt, str) else prompt

        block_state.prompt_embeds = self._get_mistral_3_prompt_embeds(
            text_encoder=components.text_encoder,
            tokenizer=components.tokenizer,
            prompt=prompt,
            device=block_state.device,
            max_sequence_length=block_state.max_sequence_length,
            system_message=self.DEFAULT_SYSTEM_MESSAGE,
            hidden_states_layers=block_state.text_encoder_out_layers,
        )

        self.set_block_state(state, block_state)
        return components, state


class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
    model_name = "flux2"

    REMOTE_URL = "https://remote-text-encoder-flux-2.huggingface.co/predict"

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text embeddings using a remote API endpoint"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Text embeddings from remote API used to guide the image generation",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        prompt = block_state.prompt
        prompt_embeds = getattr(block_state, "prompt_embeds", None)

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
                "Please make sure to only forward one of the two."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        import io

        import requests
        from huggingface_hub import get_token

        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.device = components._execution_device

        if block_state.prompt_embeds is not None:
            self.set_block_state(state, block_state)
            return components, state

        prompt = block_state.prompt
        if prompt is None:
            prompt = ""
        prompt = [prompt] if isinstance(prompt, str) else prompt

        response = requests.post(
            self.REMOTE_URL,
            json={"prompt": prompt},
            headers={
                "Authorization": f"Bearer {get_token()}",
                "Content-Type": "application/json",
            },
        )
        response.raise_for_status()

        block_state.prompt_embeds = torch.load(io.BytesIO(response.content), weights_only=True)
        block_state.prompt_embeds = block_state.prompt_embeds.to(block_state.device)

        self.set_block_state(state, block_state)
        return components, state


class Flux2VaeEncoderStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return "VAE Encoder step that encodes preprocessed images into latent representations for Flux2."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("vae", AutoencoderKLFlux2)]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("condition_images", type_hint=List[torch.Tensor]),
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_latents",
                type_hint=List[torch.Tensor],
                description="List of latent representations for each reference image",
            ),
        ]

    @staticmethod
    def _patchify_latents(latents):
        """Convert latents to patchified format for Flux2."""
        batch_size, num_channels_latents, height, width = latents.shape
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 1, 3, 5, 2, 4)
        latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
        return latents

    def _encode_vae_image(self, vae: AutoencoderKLFlux2, image: torch.Tensor, generator: torch.Generator):
        """Encode a single image using Flux2 VAE with batch norm normalization."""
        if image.ndim != 4:
            raise ValueError(f"Expected image dims 4, got {image.ndim}.")

        image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode="argmax")
        image_latents = self._patchify_latents(image_latents)

        latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
        latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps)
        latents_bn_std = latents_bn_std.to(image_latents.device, image_latents.dtype)
        image_latents = (image_latents - latents_bn_mean) / latents_bn_std

        return image_latents

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        condition_images = block_state.condition_images

        if condition_images is None:
            return components, state

        device = components._execution_device
        dtype = components.vae.dtype

        image_latents = []
        for image in condition_images:
            image = image.to(device=device, dtype=dtype)
            latent = self._encode_vae_image(
                vae=components.vae,
                image=image,
                generator=block_state.generator,
            )
            image_latents.append(latent)

        block_state.image_latents = image_latents

        self.set_block_state(state, block_state)
        return components, state
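A small shape check, not from the diff, for the patchify step defined above:

```py
# Illustration only: _patchify_latents folds each 2x2 spatial patch into channels,
# so (B, C, H, W) becomes (B, 4 * C, H // 2, W // 2).
import torch

latents = torch.randn(1, 32, 64, 64)
patchified = latents.view(1, 32, 32, 2, 32, 2).permute(0, 1, 3, 5, 2, 4).reshape(1, 128, 32, 32)
print(patchified.shape)  # torch.Size([1, 128, 32, 32])
```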
@@ -1,160 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import torch

from ...configuration_utils import FrozenDict
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline


logger = logging.get_logger(__name__)


class Flux2TextInputStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return (
            "This step:\n"
            " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
            " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_images_per_prompt", default=1),
            InputParam(
                "prompt_embeds",
                required=True,
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[str]:
        return [
            OutputParam(
                "batch_size",
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
            ),
            OutputParam(
                "dtype",
                type_hint=torch.dtype,
                description="Data type of model tensor inputs (determined by `prompt_embeds`)",
            ),
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Text embeddings used to guide the image generation",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.batch_size = block_state.prompt_embeds.shape[0]
        block_state.dtype = block_state.prompt_embeds.dtype

        _, seq_len, _ = block_state.prompt_embeds.shape
        block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
        block_state.prompt_embeds = block_state.prompt_embeds.view(
            block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
        )

        self.set_block_state(state, block_state)
        return components, state


class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return "Image preprocess step for Flux2. Validates and preprocesses reference images."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "image_processor",
                Flux2ImageProcessor,
                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image"),
            InputParam("height"),
            InputParam("width"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam(name="condition_images", type_hint=List[torch.Tensor])]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState):
        block_state = self.get_block_state(state)
        images = block_state.image

        if images is None:
            block_state.condition_images = None
            self.set_block_state(state, block_state)
            return components, state

        if not isinstance(images, list):
            images = [images]

        condition_images = []
        for img in images:
            components.image_processor.check_image_input(img)

            image_width, image_height = img.size
            if image_width * image_height > 1024 * 1024:
                img = components.image_processor._resize_to_target_area(img, 1024 * 1024)
                image_width, image_height = img.size

            multiple_of = components.vae_scale_factor * 2
            image_width = (image_width // multiple_of) * multiple_of
            image_height = (image_height // multiple_of) * multiple_of
            condition_img = components.image_processor.preprocess(
                img, height=image_height, width=image_width, resize_mode="crop"
            )
            condition_images.append(condition_img)

            if block_state.height is None:
                block_state.height = image_height
            if block_state.width is None:
                block_state.width = image_width

        block_state.condition_images = condition_images

        self.set_block_state(state, block_state)
        return components, state
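A small worked example, not from the diff, of the size-rounding rule in `Flux2ProcessImagesInputStep`; the scale factor value is an assumption based on the image processor config above.

```py
# Illustration only: reference images are snapped down to a multiple of vae_scale_factor * 2.
vae_scale_factor = 16          # assumed value, matching the Flux2ImageProcessor config above
multiple_of = vae_scale_factor * 2

image_width, image_height = 1000, 600                        # already under the 1024 * 1024 area cap
image_width = (image_width // multiple_of) * multiple_of     # 992
image_height = (image_height // multiple_of) * multiple_of   # 576
```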
@@ -1,166 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
    Flux2PrepareImageLatentsStep,
    Flux2PrepareLatentsStep,
    Flux2RoPEInputsStep,
    Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .denoise import Flux2DenoiseStep
from .encoders import (
    Flux2RemoteTextEncoderStep,
    Flux2TextEncoderStep,
    Flux2VaeEncoderStep,
)
from .inputs import (
    Flux2ProcessImagesInputStep,
    Flux2TextInputStep,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


Flux2VaeEncoderBlocks = InsertableDict(
    [
        ("preprocess", Flux2ProcessImagesInputStep()),
        ("encode", Flux2VaeEncoderStep()),
        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
    ]
)


class Flux2VaeEncoderSequentialStep(SequentialPipelineBlocks):
    model_name = "flux2"

    block_classes = Flux2VaeEncoderBlocks.values()
    block_names = Flux2VaeEncoderBlocks.keys()

    @property
    def description(self) -> str:
        return "VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning."


class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
    block_classes = [Flux2VaeEncoderSequentialStep]
    block_names = ["img_conditioning"]
    block_trigger_inputs = ["image"]

    @property
    def description(self):
        return (
            "VAE encoder step that encodes the image inputs into their latent representations.\n"
            "This is an auto pipeline block that works for image conditioning tasks.\n"
            " - `Flux2VaeEncoderSequentialStep` is used when `image` is provided.\n"
            " - If `image` is not provided, step will be skipped."
        )


Flux2BeforeDenoiseBlocks = InsertableDict(
    [
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
    ]
)


class Flux2BeforeDenoiseStep(SequentialPipelineBlocks):
    model_name = "flux2"

    block_classes = Flux2BeforeDenoiseBlocks.values()
    block_names = Flux2BeforeDenoiseBlocks.keys()

    @property
    def description(self):
        return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation."


AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", Flux2TextEncoderStep()),
        ("text_input", Flux2TextInputStep()),
        ("vae_image_encoder", Flux2AutoVaeEncoderStep()),
        ("before_denoise", Flux2BeforeDenoiseStep()),
        ("denoise", Flux2DenoiseStep()),
        ("decode", Flux2DecodeStep()),
    ]
)


REMOTE_AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", Flux2RemoteTextEncoderStep()),
        ("text_input", Flux2TextInputStep()),
        ("vae_image_encoder", Flux2AutoVaeEncoderStep()),
        ("before_denoise", Flux2BeforeDenoiseStep()),
        ("denoise", Flux2DenoiseStep()),
        ("decode", Flux2DecodeStep()),
    ]
)


class Flux2AutoBlocks(SequentialPipelineBlocks):
    model_name = "flux2"

    block_classes = AUTO_BLOCKS.values()
    block_names = AUTO_BLOCKS.keys()

    @property
    def description(self):
        return (
            "Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2.\n"
            "- For text-to-image generation, all you need to provide is `prompt`.\n"
            "- For image-conditioned generation, you need to provide `image` (list of PIL images)."
        )


TEXT2IMAGE_BLOCKS = InsertableDict(
    [
        ("text_encoder", Flux2TextEncoderStep()),
        ("text_input", Flux2TextInputStep()),
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
        ("denoise", Flux2DenoiseStep()),
        ("decode", Flux2DecodeStep()),
    ]
)

IMAGE_CONDITIONED_BLOCKS = InsertableDict(
    [
        ("text_encoder", Flux2TextEncoderStep()),
        ("text_input", Flux2TextInputStep()),
        ("preprocess_images", Flux2ProcessImagesInputStep()),
        ("vae_encoder", Flux2VaeEncoderStep()),
        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
        ("denoise", Flux2DenoiseStep()),
        ("decode", Flux2DecodeStep()),
    ]
)

ALL_BLOCKS = {
    "text2image": TEXT2IMAGE_BLOCKS,
    "image_conditioned": IMAGE_CONDITIONED_BLOCKS,
    "auto": AUTO_BLOCKS,
    "remote": REMOTE_AUTO_BLOCKS,
}
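A hedged usage sketch, not part of the diff, showing how the presets above could be inspected; only names defined in this file are used.

```py
# Hypothetical sketch: the preset dictionaries map task names to ordered block sets.
blocks = Flux2AutoBlocks()
print(list(blocks.block_names))          # text_encoder, text_input, vae_image_encoder, ...

t2i_blocks = ALL_BLOCKS["text2image"]    # InsertableDict of the text-to-image steps
print(list(t2i_blocks.keys()))
```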
@@ -1,57 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
    """
    A ModularPipeline for Flux2.

    > [!WARNING]
    > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "Flux2AutoBlocks"

    @property
    def default_height(self):
        return self.default_sample_size * self.vae_scale_factor

    @property
    def default_width(self):
        return self.default_sample_size * self.vae_scale_factor

    @property
    def default_sample_size(self):
        return 128

    @property
    def vae_scale_factor(self):
        vae_scale_factor = 8
        if getattr(self, "vae", None) is not None:
            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        return vae_scale_factor

    @property
    def num_channels_latents(self):
        num_channels_latents = 32
        if getattr(self, "transformer", None):
            num_channels_latents = self.transformer.config.in_channels // 4
        return num_channels_latents
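For reference, and not part of the diff, the default resolution implied by the properties above:

```py
# Illustration only: with no VAE registered, vae_scale_factor falls back to 8,
# so the pipeline's default height and width are 128 * 8 = 1024 pixels.
default_sample_size = 128
vae_scale_factor = 8
print(default_sample_size * vae_scale_factor)  # 1024
```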
@@ -58,7 +58,6 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
        ("wan", "WanModularPipeline"),
        ("flux", "FluxModularPipeline"),
        ("flux-kontext", "FluxKontextModularPipeline"),
        ("flux2", "Flux2ModularPipeline"),
        ("qwenimage", "QwenImageModularPipeline"),
        ("qwenimage-edit", "QwenImageEditModularPipeline"),
        ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
@@ -1587,6 +1586,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
        for name, config_spec in self._config_specs.items():
            default_configs[name] = config_spec.default
        self.register_to_config(**default_configs)

        self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)

    @property
@@ -2,36 +2,6 @@
from ..utils import DummyObject, requires_backends


class Flux2AutoBlocks(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class Flux2ModularPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class FluxAutoBlocks(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
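A short hedged note, not part of the diff: the dummy objects above exist so that referencing the Flux2 modular classes without the required backends fails with an actionable message rather than an opaque error.

```py
# Hypothetical sketch: instantiating a dummy object when torch/transformers are missing
# is expected to raise through requires_backends with an explanatory message.
try:
    Flux2ModularPipeline()
except ImportError as err:
    print(err)   # explains that the "torch" and "transformers" backends are required
```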
@@ -13,16 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AuraFlowPipeline,
|
||||
AuraFlowTransformer2DModel,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
)
|
||||
from diffusers import AuraFlowPipeline, AuraFlowTransformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||
|
||||
from ..testing_utils import (
|
||||
floats_tensor,
|
||||
@@ -40,7 +36,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
class TestAuraFlowLoRA(PeftLoraLoaderMixinTests):
|
||||
pipeline_class = AuraFlowPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -103,34 +99,34 @@ class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@unittest.skip("Not supported in AuraFlow.")
|
||||
@pytest.mark.skip("Not supported in AuraFlow.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in AuraFlow.")
|
||||
@pytest.mark.skip("Not supported in AuraFlow.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in AuraFlow.")
|
||||
@pytest.mark.skip("Not supported in AuraFlow.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in AuraFlow.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -13,10 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
@@ -39,7 +38,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
class TestCogVideoXLoRA(PeftLoraLoaderMixinTests):
|
||||
pipeline_class = CogVideoXPipeline
|
||||
scheduler_cls = CogVideoXDPMScheduler
|
||||
scheduler_kwargs = {"timestep_spacing": "trailing"}
|
||||
@@ -119,54 +118,59 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3, pipe=pipe)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
|
||||
|
||||
def test_lora_scale_kwargs_match_fusion(self):
|
||||
super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
|
||||
def test_lora_scale_kwargs_match_fusion(self, base_pipe_output):
|
||||
super().test_lora_scale_kwargs_match_fusion(
|
||||
base_pipe_output=base_pipe_output, expected_atol=9e-3, expected_rtol=9e-3
|
||||
)
|
||||
|
||||
@parameterized.expand([("block_level", True), ("leaf_level", False)])
|
||||
@pytest.mark.parametrize(
|
||||
"offload_type, use_stream",
|
||||
[("block_level", True), ("leaf_level", False)],
|
||||
)
|
||||
@require_torch_accelerator
|
||||
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
|
||||
def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe):
|
||||
# TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
|
||||
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
|
||||
super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
|
||||
super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe)
|
||||
|
||||
@unittest.skip("Not supported in CogVideoX.")
|
||||
@pytest.mark.skip("Not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in CogVideoX.")
|
||||
@pytest.mark.skip("Not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in CogVideoX.")
|
||||
@pytest.mark.skip("Not supported in CogVideoX.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in CogVideoX.")
|
||||
@pytest.mark.skip("Not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
|
||||
pass
|
||||
|
||||
@@ -13,12 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoTokenizer, GlmModel
|
||||
|
||||
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
|
||||
@@ -28,7 +25,6 @@ from ..testing_utils import (
|
||||
require_peft_backend,
|
||||
require_torch_accelerator,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
@@ -47,7 +43,7 @@ class TokenizerWrapper:
|
||||
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
class TestCogView4LoRA(PeftLoraLoaderMixinTests):
|
||||
pipeline_class = CogView4Pipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -113,72 +109,50 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_save_pretrained(self):
|
||||
"""
|
||||
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
|
||||
"""
|
||||
components, _, _ = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname)
|
||||
|
||||
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
|
||||
pipe_from_pretrained.to(torch_device)
|
||||
|
||||
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Loading from saved checkpoints should give same results.",
|
||||
)
|
||||
|
||||
@parameterized.expand([("block_level", True), ("leaf_level", False)])
|
||||
@pytest.mark.parametrize(
|
||||
"offload_type, use_stream",
|
||||
[("block_level", True), ("leaf_level", False)],
|
||||
)
|
||||
@require_torch_accelerator
|
||||
def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
|
||||
def test_group_offloading_inference_denoiser(self, offload_type, use_stream, tmpdirname, pipe):
|
||||
# TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
|
||||
# The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
|
||||
super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
|
||||
super()._test_group_offloading_inference_denoiser(offload_type, use_stream, tmpdirname, pipe)
|
||||
|
||||
@unittest.skip("Not supported in CogView4.")
|
||||
@pytest.mark.skip("Not supported in CogView4.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in CogView4.")
|
||||
@pytest.mark.skip("Not supported in CogView4.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in CogView4.")
|
||||
@pytest.mark.skip("Not supported in CogView4.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in CogView4.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -16,13 +16,11 @@ import copy
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
from PIL import Image
|
||||
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
@@ -46,14 +44,12 @@ from ..testing_utils import (
|
||||
|
||||
if is_peft_available():
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa: E402
|
||||
from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
class TestFluxLoRA(PeftLoraLoaderMixinTests):
|
||||
pipeline_class = FluxPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -115,165 +111,134 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_with_alpha_in_state_dict(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
def test_with_alpha_in_state_dict(self, tmpdirname, pipe):
|
||||
_, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
pipe.transformer.add_adapter(denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
# modify the state dict to have alpha values following
|
||||
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
|
||||
state_dict_with_alpha = safetensors.torch.load_file(
|
||||
os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
|
||||
)
|
||||
alpha_dict = {}
|
||||
for k, v in state_dict_with_alpha.items():
|
||||
# only do for `transformer` and for the k projections -- should be enough to test.
|
||||
if "transformer" in k and "to_k" in k and "lora_A" in k:
|
||||
alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
|
||||
state_dict_with_alpha.update(alpha_dict)
|
||||
# modify the state dict to have alpha values following
|
||||
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA/blob/main/jon_snow.safetensors
|
||||
state_dict_with_alpha = safetensors.torch.load_file(
|
||||
os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
|
||||
)
|
||||
alpha_dict = {}
|
||||
for k, v in state_dict_with_alpha.items():
|
||||
if "transformer" in k and "to_k" in k and ("lora_A" in k):
|
||||
alpha_dict[f"{k}.alpha"] = float(torch.randint(10, 100, size=()))
|
||||
state_dict_with_alpha.update(alpha_dict)
|
||||
|
||||
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(state_dict_with_alpha)
|
||||
images_lora_with_alpha = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Loading from saved checkpoints should give same results.",
|
||||
assert np.allclose(images_lora, images_lora_from_pretrained, atol=0.001, rtol=0.001), (
|
||||
"Loading from saved checkpoints should give same results."
|
||||
)
|
||||
self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
|
||||
assert not np.allclose(images_lora_with_alpha, images_lora, atol=0.001, rtol=0.001)
|
||||
|
||||
def test_lora_expansion_works_for_absent_keys(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
def test_lora_expansion_works_for_absent_keys(self, base_pipe_output, tmpdirname, pipe):
|
||||
_, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
output_no_lora = self.get_base_pipe_output()
|
||||
|
||||
# Modify the config to have a layer which won't be present in the second LoRA we will load.
|
||||
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
|
||||
modified_denoiser_lora_config.target_modules.add("x_embedder")
|
||||
|
||||
pipe.transformer.add_adapter(modified_denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
self.assertFalse(
|
||||
np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should lead to different results.",
|
||||
assert not np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), (
|
||||
"LoRA should lead to different results."
|
||||
)
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
|
||||
|
||||
# Modify the state dict to exclude "x_embedder" related LoRA params.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
|
||||
pipe.unload_lora_weights()
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="one")
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
|
||||
|
||||
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="two")
|
||||
pipe.set_adapters(["one", "two"])
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
|
||||
|
||||
images_lora_with_absent_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
self.assertFalse(
|
||||
np.allclose(images_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
|
||||
"Different LoRAs should lead to different results.",
|
||||
assert not np.allclose(images_lora, images_lora_with_absent_keys, atol=0.001, rtol=0.001), (
|
||||
"Different LoRAs should lead to different results."
|
||||
)
|
||||
self.assertFalse(
|
||||
np.allclose(output_no_lora, images_lora_with_absent_keys, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should lead to different results.",
|
||||
assert not np.allclose(base_pipe_output, images_lora_with_absent_keys, atol=0.001, rtol=0.001), (
|
||||
"LoRA should lead to different results."
|
||||
)
|
||||
|
||||
def test_lora_expansion_works_for_extra_keys(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
def test_lora_expansion_works_for_extra_keys(self, base_pipe_output, tmpdirname, pipe):
|
||||
_, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
output_no_lora = self.get_base_pipe_output()
|
||||
|
||||
# Modify the config to have a layer which won't be present in the first LoRA we will load.
|
||||
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
|
||||
modified_denoiser_lora_config.target_modules.add("x_embedder")
|
||||
|
||||
pipe.transformer.add_adapter(modified_denoiser_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
|
||||
|
||||
images_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
self.assertFalse(
|
||||
np.allclose(images_lora, output_no_lora, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should lead to different results.",
|
||||
assert not np.allclose(images_lora, base_pipe_output, atol=0.001, rtol=0.001), (
|
||||
"LoRA should lead to different results."
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
self.pipeline_class.save_lora_weights(tmpdirname, transformer_lora_layers=denoiser_state_dict)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
|
||||
pipe.unload_lora_weights()
|
||||
# Modify the state dict to exclude "x_embedder" related LoRA params.
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
lora_state_dict_without_xembedder = {k: v for k, v in lora_state_dict.items() if "x_embedder" not in k}
|
||||
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
|
||||
|
||||
# Load state dict with `x_embedder`.
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
|
||||
pipe.unload_lora_weights()
|
||||
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
lora_state_dict_without_xembedder = {k: v for (k, v) in lora_state_dict.items() if "x_embedder" not in k}
|
||||
pipe.load_lora_weights(lora_state_dict_without_xembedder, adapter_name="one")
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"), adapter_name="two")
|
||||
|
||||
pipe.set_adapters(["one", "two"])
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer"
|
||||
|
||||
images_lora_with_extra_keys = pipe(**inputs, generator=torch.manual_seed(0)).images
|
||||
|
||||
self.assertFalse(
|
||||
np.allclose(images_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
|
||||
"Different LoRAs should lead to different results.",
|
||||
assert not np.allclose(images_lora, images_lora_with_extra_keys, atol=0.001, rtol=0.001), (
|
||||
"Different LoRAs should lead to different results."
|
||||
)
|
||||
self.assertFalse(
|
||||
np.allclose(output_no_lora, images_lora_with_extra_keys, atol=1e-3, rtol=1e-3),
|
||||
"LoRA should lead to different results.",
|
||||
assert not np.allclose(base_pipe_output, images_lora_with_extra_keys, atol=0.001, rtol=0.001), (
|
||||
"LoRA should lead to different results."
|
||||
)
|
||||
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
|
||||
pass
|
||||
|
||||
|
||||
class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
class TestFluxControlLoRA(PeftLoraLoaderMixinTests):
|
||||
pipeline_class = FluxControlPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -338,12 +303,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_with_norm_in_state_dict(self):
|
||||
components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
def test_with_norm_in_state_dict(self, pipe):
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
@@ -364,39 +324,32 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipe.load_lora_weights(norm_state_dict)
|
||||
lora_load_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(
|
||||
assert (
|
||||
"The provided state dict contains normalization layers in addition to LoRA layers"
|
||||
in cap_logger.out
|
||||
)
|
||||
self.assertTrue(len(pipe.transformer._transformer_norm_layers) > 0)
|
||||
assert len(pipe.transformer._transformer_norm_layers) > 0
|
||||
|
||||
pipe.unload_lora_weights()
|
||||
lora_unload_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue(pipe.transformer._transformer_norm_layers is None)
|
||||
self.assertTrue(np.allclose(original_output, lora_unload_output, atol=1e-5, rtol=1e-5))
|
||||
self.assertFalse(
|
||||
np.allclose(original_output, lora_load_output, atol=1e-6, rtol=1e-6), f"{norm_layer} is tested"
|
||||
assert pipe.transformer._transformer_norm_layers is None
|
||||
assert np.allclose(original_output, lora_unload_output, atol=1e-05, rtol=1e-05)
|
||||
assert not np.allclose(original_output, lora_load_output, atol=1e-06, rtol=1e-06), (
|
||||
f"{norm_layer} is tested"
|
||||
)
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
for key in list(norm_state_dict.keys()):
|
||||
norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key)
|
||||
pipe.load_lora_weights(norm_state_dict)
|
||||
assert "Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
|
||||
|
||||
self.assertTrue(
|
||||
"Unsupported keys found in state dict when trying to load normalization layers" in cap_logger.out
|
||||
)
|
||||
|
||||
def test_lora_parameter_expanded_shapes(self):
|
||||
def test_lora_parameter_expanded_shapes(self, pipe):
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
@@ -405,24 +358,21 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
transformer = FluxTransformer2DModel.from_config(
|
||||
components["transformer"].config, in_channels=num_channels_without_control
|
||||
).to(torch_device)
|
||||
self.assertTrue(
|
||||
transformer.config.in_channels == num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
|
||||
assert transformer.config.in_channels == num_channels_without_control, (
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
|
||||
)
|
||||
|
||||
original_transformer_state_dict = pipe.transformer.state_dict()
|
||||
x_embedder_weight = original_transformer_state_dict.pop("x_embedder.weight")
|
||||
incompatible_keys = transformer.load_state_dict(original_transformer_state_dict, strict=False)
|
||||
self.assertTrue(
|
||||
"x_embedder.weight" in incompatible_keys.missing_keys,
|
||||
"Could not find x_embedder.weight in the missing keys.",
|
||||
assert "x_embedder.weight" in incompatible_keys.missing_keys, (
|
||||
"Could not find x_embedder.weight in the missing keys."
|
||||
)
|
||||
|
||||
transformer.x_embedder.weight.data.copy_(x_embedder_weight[..., :num_channels_without_control])
|
||||
pipe.transformer = transformer
|
||||
|
||||
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
||||
rank = 4
|
||||
|
||||
dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
|
||||
dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
|
||||
lora_state_dict = {
|
||||
@@ -431,15 +381,13 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
|
||||
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
|
||||
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
|
||||
assert pipe.transformer.config.in_channels == 2 * in_features
|
||||
assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
|
||||
|
||||
# Testing opposite direction where the LoRA params are zero-padded.
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
@@ -454,15 +402,13 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
|
||||
lora_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
|
||||
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
|
||||
assert pipe.transformer.config.in_channels == 2 * in_features
|
||||
assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
|
||||
|
||||
def test_normal_lora_with_expanded_lora_raises_error(self):
|
||||
# Test the following situation. Load a regular LoRA (such as the ones trained on Flux.1-Dev). And then
|
||||
@@ -494,32 +440,28 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
self.assertTrue(pipe.get_active_adapters() == ["adapter-1"])
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
assert pipe.get_active_adapters() == ["adapter-1"]
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
|
||||
assert pipe.transformer.config.in_channels == 2 * in_features
|
||||
assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
|
||||
(_, _, inputs) = self.get_dummy_inputs(with_generator=False)
|
||||
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
|
||||
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
|
||||
}
|
||||
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-2")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
|
||||
self.assertTrue(pipe.get_active_adapters() == ["adapter-2"])
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
|
||||
assert pipe.get_active_adapters() == ["adapter-2"]
|
||||
|
||||
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
|
||||
assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001)
|
||||
|
||||
# Test the opposite case where the first lora has the correct input features and the second lora has expanded input features.
|
||||
# This should raise a runtime error on input shapes being incompatible.
|
||||
@@ -540,32 +482,24 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
||||
rank = 4
|
||||
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
|
||||
}
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == in_features)
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features
|
||||
assert pipe.transformer.config.in_channels == in_features
|
||||
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": shape_expander_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
|
||||
}
|
||||
|
||||
# We should check for input shapes being incompatible here. But because above mentioned issue is
|
||||
# not a supported use case, and because of the PEFT renaming, we will currently have a shape
|
||||
# mismatch error.
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"size mismatch for x_embedder.lora_A.adapter-2.weight",
|
||||
pipe.load_lora_weights,
|
||||
lora_state_dict,
|
||||
"adapter-2",
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="size mismatch for x_embedder.lora_A.adapter-2.weight"):
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-2")
|
||||
|
||||
def test_fuse_expanded_lora_with_regular_lora(self):
|
||||
# This test checks if it works when a lora with expanded shapes (like control loras) but
|
||||
@@ -597,7 +531,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
"transformer.x_embedder.lora_B.weight": shape_expander_lora_B.weight,
|
||||
}
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
@@ -610,54 +544,44 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
}
|
||||
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-2")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
|
||||
lora_output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
pipe.set_adapters(["adapter-1", "adapter-2"], [1.0, 1.0])
|
||||
lora_output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(lora_output, lora_output_2, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(lora_output, lora_output_3, atol=1e-3, rtol=1e-3))
|
||||
self.assertFalse(np.allclose(lora_output_2, lora_output_3, atol=1e-3, rtol=1e-3))
|
||||
assert not np.allclose(lora_output, lora_output_2, atol=0.001, rtol=0.001)
|
||||
assert not np.allclose(lora_output, lora_output_3, atol=0.001, rtol=0.001)
|
||||
assert not np.allclose(lora_output_2, lora_output_3, atol=0.001, rtol=0.001)
|
||||
|
||||
pipe.fuse_lora(lora_scale=1.0, adapter_names=["adapter-1", "adapter-2"])
|
||||
lora_output_4 = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
self.assertTrue(np.allclose(lora_output_3, lora_output_4, atol=1e-3, rtol=1e-3))
|
||||
assert np.allclose(lora_output_3, lora_output_4, atol=0.001, rtol=0.001)
|
||||
|
||||
def test_load_regular_lora(self):
|
||||
def test_load_regular_lora(self, base_pipe_output, pipe):
|
||||
# This test checks if a regular lora (think of one trained on Flux.1 Dev for example) can be loaded
|
||||
# into the transformer with more input channels than Flux.1 Dev, for example. Some examples of those
|
||||
# transformers include Flux Fill, Flux Control, etc.
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
||||
rank = 4
|
||||
in_features = in_features // 2 # to mimic the Flux.1-Dev LoRA.
|
||||
in_features = in_features // 2
|
||||
normal_lora_A = torch.nn.Linear(in_features, rank, bias=False)
|
||||
normal_lora_B = torch.nn.Linear(rank, out_features, bias=False)
|
||||
lora_state_dict = {
|
||||
"transformer.x_embedder.lora_A.weight": normal_lora_A.weight,
|
||||
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
|
||||
}
|
||||
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger.setLevel(logging.INFO)
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
|
||||
lora_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertTrue("The following LoRA modules were zero padded to match the state dict of" in cap_logger.out)
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
|
||||
self.assertFalse(np.allclose(original_output, lora_output, atol=1e-3, rtol=1e-3))
|
||||
assert "The following LoRA modules were zero padded to match the state dict of" in cap_logger.out
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2
|
||||
assert not np.allclose(base_pipe_output, lora_output, atol=0.001, rtol=0.001)
|
||||
|
||||
def test_lora_unload_with_parameter_expanded_shapes(self):
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
@@ -670,9 +594,8 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
transformer = FluxTransformer2DModel.from_config(
|
||||
components["transformer"].config, in_channels=num_channels_without_control
|
||||
).to(torch_device)
|
||||
self.assertTrue(
|
||||
transformer.config.in_channels == num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
|
||||
assert transformer.config.in_channels == num_channels_without_control, (
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
|
||||
)
|
||||
|
||||
# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
|
||||
@@ -697,33 +620,31 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
|
||||
inputs["control_image"] = control_image
|
||||
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
|
||||
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
|
||||
assert pipe.transformer.config.in_channels == 2 * in_features
|
||||
assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
|
||||
|
||||
control_pipe.unload_lora_weights(reset_to_overwritten_params=True)
|
||||
self.assertTrue(
|
||||
control_pipe.transformer.config.in_channels == num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
|
||||
assert control_pipe.transformer.config.in_channels == num_channels_without_control, (
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}"
|
||||
)
|
||||
|
||||
loaded_pipe = FluxPipeline.from_pipe(control_pipe)
|
||||
self.assertTrue(
|
||||
loaded_pipe.transformer.config.in_channels == num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {loaded_pipe.transformer.config.in_channels=}",
|
||||
assert loaded_pipe.transformer.config.in_channels == num_channels_without_control, (
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has loaded_pipe.transformer.config.in_channels={loaded_pipe.transformer.config.in_channels!r}"
|
||||
)
|
||||
|
||||
inputs.pop("control_image")
|
||||
unloaded_lora_out = loaded_pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(unloaded_lora_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(np.allclose(unloaded_lora_out, original_out, atol=1e-4, rtol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == in_features)
|
||||
assert not np.allclose(unloaded_lora_out, lora_out, rtol=0.0001, atol=0.0001)
|
||||
assert np.allclose(unloaded_lora_out, original_out, atol=0.0001, rtol=0.0001)
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features
|
||||
assert pipe.transformer.config.in_channels == in_features
|
||||
|
||||
def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
|
||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||
@@ -731,14 +652,12 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Change the transformer config to mimic a real use case.
|
||||
num_channels_without_control = 4
|
||||
transformer = FluxTransformer2DModel.from_config(
|
||||
components["transformer"].config, in_channels=num_channels_without_control
|
||||
).to(torch_device)
|
||||
self.assertTrue(
|
||||
transformer.config.in_channels == num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {transformer.config.in_channels=}",
|
||||
assert transformer.config.in_channels == num_channels_without_control, (
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has transformer.config.in_channels={transformer.config.in_channels!r}"
|
||||
)
|
||||
|
||||
# This should be initialized with a Flux pipeline variant that doesn't accept `control_image`.
|
||||
@@ -763,40 +682,38 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
}
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
control_pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser")
|
||||
assert check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in denoiser"
|
||||
|
||||
inputs["control_image"] = control_image
|
||||
lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
|
||||
self.assertFalse(np.allclose(original_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == 2 * in_features)
|
||||
self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
|
||||
assert not np.allclose(original_out, lora_out, rtol=0.0001, atol=0.0001)
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == 2 * in_features
|
||||
assert pipe.transformer.config.in_channels == 2 * in_features
|
||||
assert cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")
|
||||
|
||||
control_pipe.unload_lora_weights(reset_to_overwritten_params=False)
|
||||
self.assertTrue(
|
||||
control_pipe.transformer.config.in_channels == 2 * num_channels_without_control,
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has {control_pipe.transformer.config.in_channels=}",
|
||||
assert control_pipe.transformer.config.in_channels == 2 * num_channels_without_control, (
|
||||
f"Expected {num_channels_without_control} channels in the modified transformer but has control_pipe.transformer.config.in_channels={control_pipe.transformer.config.in_channels!r}"
|
||||
)
|
||||
|
||||
no_lora_out = control_pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||
assert not np.allclose(no_lora_out, lora_out, rtol=0.0001, atol=0.0001)
|
||||
assert pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2
|
||||
assert pipe.transformer.config.in_channels == in_features * 2
|
||||
|
||||
self.assertFalse(np.allclose(no_lora_out, lora_out, rtol=1e-4, atol=1e-4))
|
||||
self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
|
||||
self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)
|
||||
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Flux.")
|
||||
@pytest.mark.skip("Not supported in Flux.")
|
||||
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
|
||||
pass
|
||||
|
||||
@@ -806,7 +723,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
@require_big_accelerator
|
||||
class FluxLoRAIntegrationTests(unittest.TestCase):
|
||||
class TestFluxLoRAIntegration:
|
||||
"""internal note: The integration slices were obtained on audace.
|
||||
|
||||
torch: 2.6.0.dev20241006+cu124 with CUDA 12.5. Need the same setup for the
|
||||
@@ -816,33 +733,27 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
|
||||
num_inference_steps = 10
|
||||
seed = 0
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def pipeline(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
)
|
||||
try:
|
||||
yield pipe
|
||||
finally:
|
||||
del pipe
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
self.pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
del self.pipeline
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_flux_the_last_ben(self):
|
||||
self.pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
# Instead of calling `enable_model_cpu_offload()`, we do a accelerator placement here because the CI
|
||||
# run supports it. We have about 34GB RAM in the CI runner which kills the test when run with
|
||||
# `enable_model_cpu_offload()`. We repeat this for the other tests, too.
|
||||
self.pipeline = self.pipeline.to(torch_device)
|
||||
|
||||
def test_flux_the_last_ben(self, pipeline):
|
||||
pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
pipeline = pipeline.to(torch_device)
|
||||
prompt = "jon snow eating pizza with ketchup"
|
||||
|
||||
out = self.pipeline(
|
||||
out = pipeline(
|
||||
prompt,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
guidance_scale=4.0,
|
||||
@@ -851,71 +762,57 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
|
||||
).images
|
||||
out_slice = out[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.1855, 0.1855, 0.1836, 0.1855, 0.1836, 0.1875, 0.1777, 0.1758, 0.2246])
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
assert max_diff < 0.001
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_flux_kohya(self):
|
||||
self.pipeline.load_lora_weights("Norod78/brain-slug-flux")
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
self.pipeline = self.pipeline.to(torch_device)
|
||||
|
||||
def test_flux_kohya(self, pipeline):
|
||||
pipeline.load_lora_weights("Norod78/brain-slug-flux")
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
pipeline = pipeline.to(torch_device)
|
||||
prompt = "The cat with a brain slug earring"
|
||||
out = self.pipeline(
|
||||
out = pipeline(
|
||||
prompt,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
guidance_scale=4.5,
|
||||
output_type="np",
|
||||
generator=torch.manual_seed(self.seed),
|
||||
).images
|
||||
|
||||
out_slice = out[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.6367, 0.6367, 0.6328, 0.6367, 0.6328, 0.6289, 0.6367, 0.6328, 0.6484])
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
assert max_diff < 0.001
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_flux_kohya_with_text_encoder(self):
|
||||
self.pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
self.pipeline = self.pipeline.to(torch_device)
|
||||
|
||||
def test_flux_kohya_with_text_encoder(self, pipeline):
|
||||
pipeline.load_lora_weights("cocktailpeanut/optimus", weight_name="optimus.safetensors")
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
pipeline = pipeline.to(torch_device)
|
||||
prompt = "optimus is cleaning the house with broomstick"
|
||||
out = self.pipeline(
|
||||
out = pipeline(
|
||||
prompt,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
guidance_scale=4.5,
|
||||
output_type="np",
|
||||
generator=torch.manual_seed(self.seed),
|
||||
).images
|
||||
|
||||
out_slice = out[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.4023, 0.4023, 0.4023, 0.3965, 0.3984, 0.3965, 0.3926, 0.3906, 0.4219])
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
assert max_diff < 0.001
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_flux_kohya_embedders_conversion(self):
|
||||
def test_flux_kohya_embedders_conversion(self, pipeline):
|
||||
"""Test that embedders load without throwing errors"""
|
||||
self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
|
||||
self.pipeline.unload_lora_weights()
|
||||
|
||||
assert True
|
||||
|
||||
def test_flux_xlabs(self):
|
||||
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
self.pipeline = self.pipeline.to(torch_device)
|
||||
pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
|
||||
pipeline.unload_lora_weights()
|
||||
|
||||
def test_flux_xlabs(self, pipeline):
|
||||
pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
pipeline = pipeline.to(torch_device)
|
||||
prompt = "A blue jay standing on a large basket of rainbow macarons, disney style"
|
||||
|
||||
out = self.pipeline(
|
||||
out = pipeline(
|
||||
prompt,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
guidance_scale=3.5,
|
||||
@@ -923,23 +820,17 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
|
||||
generator=torch.manual_seed(self.seed),
|
||||
).images
|
||||
out_slice = out[0, -3:, -3:, -1].flatten()
|
||||
expected_slice = np.array([0.3965, 0.4180, 0.4434, 0.4082, 0.4375, 0.4590, 0.4141, 0.4375, 0.4980])
|
||||
|
||||
expected_slice = np.array([0.3965, 0.418, 0.4434, 0.4082, 0.4375, 0.459, 0.4141, 0.4375, 0.498])
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
assert max_diff < 0.001
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_flux_xlabs_load_lora_with_single_blocks(self):
|
||||
self.pipeline.load_lora_weights(
|
||||
"salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors"
|
||||
)
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
self.pipeline.enable_model_cpu_offload()
|
||||
|
||||
def test_flux_xlabs_load_lora_with_single_blocks(self, pipeline):
|
||||
pipeline.load_lora_weights("salinasr/test_xlabs_flux_lora_with_singleblocks", weight_name="lora.safetensors")
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
pipeline.enable_model_cpu_offload()
|
||||
prompt = "a wizard mouse playing chess"
|
||||
|
||||
out = self.pipeline(
|
||||
out = pipeline(
|
||||
prompt,
|
||||
num_inference_steps=self.num_inference_steps,
|
||||
guidance_scale=3.5,
|
||||
@@ -951,40 +842,43 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
|
||||
[0.04882812, 0.04101562, 0.04882812, 0.03710938, 0.02929688, 0.02734375, 0.0234375, 0.01757812, 0.0390625]
|
||||
)
|
||||
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), out_slice)
|
||||
|
||||
assert max_diff < 1e-3
|
||||
assert max_diff < 0.001
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
@require_big_accelerator
|
||||
class FluxControlLoRAIntegrationTests(unittest.TestCase):
|
||||
class TestFluxControlLoRAIntegration:
|
||||
num_inference_steps = 10
|
||||
seed = 0
|
||||
prompt = "A robot made of exotic candies and chocolates of different kinds."
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def pipeline(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
pipe = FluxControlPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
)
|
||||
try:
|
||||
yield pipe
|
||||
finally:
|
||||
del pipe
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
self.pipeline = FluxControlPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
|
||||
).to(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
|
||||
def test_lora(self, lora_ckpt_id):
|
||||
self.pipeline.load_lora_weights(lora_ckpt_id)
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
@pytest.mark.parametrize(
|
||||
"lora_ckpt_id",
|
||||
[
|
||||
"black-forest-labs/FLUX.1-Canny-dev-lora",
|
||||
"black-forest-labs/FLUX.1-Depth-dev-lora",
|
||||
],
|
||||
)
|
||||
def test_lora(self, pipeline, lora_ckpt_id):
|
||||
pipeline.load_lora_weights(lora_ckpt_id)
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
|
||||
if "Canny" in lora_ckpt_id:
|
||||
control_image = load_image(
|
||||
@@ -995,7 +889,7 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase):
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux-control-lora/depth_condition_image.png"
|
||||
)
|
||||
|
||||
image = self.pipeline(
|
||||
image = pipeline(
|
||||
prompt=self.prompt,
|
||||
control_image=control_image,
|
||||
height=1024,
|
||||
@@ -1016,12 +910,18 @@ class FluxControlLoRAIntegrationTests(unittest.TestCase):
|
||||
|
||||
assert max_diff < 1e-3
|
||||
|
||||
@parameterized.expand(["black-forest-labs/FLUX.1-Canny-dev-lora", "black-forest-labs/FLUX.1-Depth-dev-lora"])
|
||||
def test_lora_with_turbo(self, lora_ckpt_id):
|
||||
self.pipeline.load_lora_weights(lora_ckpt_id)
|
||||
self.pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
@pytest.mark.parametrize(
|
||||
"lora_ckpt_id",
|
||||
[
|
||||
"black-forest-labs/FLUX.1-Canny-dev-lora",
|
||||
"black-forest-labs/FLUX.1-Depth-dev-lora",
|
||||
],
|
||||
)
|
||||
def test_lora_with_turbo(self, pipeline, lora_ckpt_id):
|
||||
pipeline.load_lora_weights(lora_ckpt_id)
|
||||
pipeline.load_lora_weights("ByteDance/Hyper-SD", weight_name="Hyper-FLUX.1-dev-8steps-lora.safetensors")
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
|
||||
if "Canny" in lora_ckpt_id:
|
||||
control_image = load_image(
|
||||
|
||||
@@ -13,9 +13,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||
|
||||
@@ -30,7 +30,7 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
class TestFlux2LoRA(PeftLoraLoaderMixinTests):
|
||||
pipeline_class = Flux2Pipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -133,36 +133,36 @@ class Flux2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
out = pipe(**inputs)[0]
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
assert np.isnan(out).all()
|
||||
|
||||
@unittest.skip("Not supported in Flux2.")
|
||||
@pytest.mark.skip("Not supported in Flux2.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Flux2.")
|
||||
@pytest.mark.skip("Not supported in Flux2.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Flux2.")
|
||||
@pytest.mark.skip("Not supported in Flux2.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Flux2.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -14,9 +14,9 @@
|
||||
|
||||
import gc
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
|
||||
|
||||
@@ -48,7 +48,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
class TestHunyuanVideoLoRA(PeftLoraLoaderMixinTests):
|
||||
pipeline_class = HunyuanVideoPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -149,46 +149,41 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
|
||||
|
||||
# TODO(aryan): Fix the following test
|
||||
@unittest.skip("This test fails with an error I haven't been able to debug yet.")
|
||||
def test_simple_inference_save_pretrained(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in HunyuanVideo.")
|
||||
@pytest.mark.skip("Not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in HunyuanVideo.")
|
||||
@pytest.mark.skip("Not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in HunyuanVideo.")
|
||||
@pytest.mark.skip("Not supported in HunyuanVideo.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in HunyuanVideo.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -197,7 +192,7 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
@require_torch_accelerator
|
||||
@require_peft_backend
|
||||
@require_big_accelerator
|
||||
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
|
||||
class TestHunyuanVideoLoRAIntegration:
|
||||
"""internal note: The integration slices were obtained on DGX.
|
||||
|
||||
torch: 2.5.1+cu124 with CUDA 12.5. Need the same setup for the
|
||||
@@ -207,9 +202,8 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
|
||||
num_inference_steps = 10
|
||||
seed = 0
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def pipeline(self):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@@ -217,27 +211,27 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
self.pipeline = HunyuanVideoPipeline.from_pretrained(
|
||||
model_id, transformer=transformer, torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16).to(
|
||||
torch_device
|
||||
)
|
||||
try:
|
||||
yield pipe
|
||||
finally:
|
||||
del pipe
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
def test_original_format_cseti(self):
|
||||
self.pipeline.load_lora_weights(
|
||||
def test_original_format_cseti(self, pipeline):
|
||||
pipeline.load_lora_weights(
|
||||
"Cseti/HunyuanVideo-LoRA-Arcane_Jinx-v1", weight_name="csetiarcane-nfjinx-v1-6000.safetensors"
|
||||
)
|
||||
self.pipeline.fuse_lora()
|
||||
self.pipeline.unload_lora_weights()
|
||||
self.pipeline.vae.enable_tiling()
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
pipeline.vae.enable_tiling()
|
||||
|
||||
prompt = "CSETIARCANE. A cat walks on the grass, realistic"
|
||||
|
||||
out = self.pipeline(
|
||||
out = pipeline(
|
||||
prompt=prompt,
|
||||
height=320,
|
||||
width=512,
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
class TestLTXVideoLoRA(PeftLoraLoaderMixinTests):
|
||||
pipeline_class = LTXPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -108,40 +108,40 @@ class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
|
||||
|
||||
@unittest.skip("Not supported in LTXVideo.")
|
||||
@pytest.mark.skip("Not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in LTXVideo.")
|
||||
@pytest.mark.skip("Not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in LTXVideo.")
|
||||
@pytest.mark.skip("Not supported in LTXVideo.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in LTXVideo.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -36,7 +35,7 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
class TestLumina2LoRA(PeftLoraLoaderMixinTests):
|
||||
pipeline_class = Lumina2Pipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -101,35 +100,35 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
@unittest.skip("Not supported in Lumina2.")
|
||||
@pytest.mark.skip("Not supported in Lumina2.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Lumina2.")
|
||||
@pytest.mark.skip("Not supported in Lumina2.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Lumina2.")
|
||||
@pytest.mark.skip("Not supported in Lumina2.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Lumina2.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@@ -139,20 +138,17 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
|
||||
strict=False,
|
||||
)
|
||||
def test_lora_fuse_nan(self):
|
||||
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
def test_lora_fuse_nan(self, pipe):
|
||||
_, text_lora_config, denoiser_lora_config = self.get_dummy_components()
|
||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||
|
||||
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
|
||||
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
|
||||
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
|
||||
|
||||
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
|
||||
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
|
||||
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
|
||||
assert check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser."
|
||||
|
||||
# corrupt one LoRA weight with `inf` values
|
||||
with torch.no_grad():
|
||||
@@ -166,4 +162,4 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
|
||||
out = pipe(**inputs)[0]
|
||||
|
||||
self.assertTrue(np.isnan(out).all())
|
||||
assert np.isnan(out).all()
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
class TestMochiLoRA(PeftLoraLoaderMixinTests):
|
||||
pipeline_class = MochiPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
scheduler_kwargs = {}
|
||||
@@ -99,44 +99,44 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
|
||||
return noise, input_ids, pipeline_inputs
|
||||
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
|
||||
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
|
||||
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
|
||||
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
|
||||
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
|
||||
|
||||
@unittest.skip("Not supported in Mochi.")
|
||||
@pytest.mark.skip("Not supported in Mochi.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Mochi.")
|
||||
@pytest.mark.skip("Not supported in Mochi.")
|
||||
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in Mochi.")
|
||||
@pytest.mark.skip("Not supported in Mochi.")
|
||||
def test_modify_padding_mode(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_partial_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_text_lora(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_text_lora_and_scale(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_text_lora_fused(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
@pytest.mark.skip("Text encoder LoRA is not supported in Mochi.")
|
||||
def test_simple_inference_with_text_lora_save_load(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported in CogVideoX.")
|
||||
@pytest.mark.skip("Not supported in CogVideoX.")
|
||||
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
|
||||
pass
|
||||
|
||||
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import pytest
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
@@ -34,7 +34,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestQwenImageLoRA(PeftLoraLoaderMixinTests):
pipeline_class = QwenImagePipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -96,34 +96,34 @@ class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in Qwen Image.")
@pytest.mark.skip("Not supported in Qwen Image.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Qwen Image.")
@pytest.mark.skip("Not supported in Qwen Image.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Qwen Image.")
@pytest.mark.skip("Not supported in Qwen Image.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
@pytest.mark.skip("Text encoder LoRA is not supported in Qwen Image.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import pytest
import torch
from transformers import Gemma2Model, GemmaTokenizer
@@ -29,7 +29,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestSanaLoRA(PeftLoraLoaderMixinTests):
pipeline_class = SanaPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {"shift": 7.0}
@@ -105,38 +105,38 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
@unittest.skip("Not supported in SANA.")
@pytest.mark.skip("Not supported in SANA.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Not supported in SANA.")
@pytest.mark.skip("Not supported in SANA.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in SANA.")
@pytest.mark.skip("Not supported in SANA.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in SANA.")
@pytest.mark.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
@pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_inference_denoiser(self):
return super().test_layerwise_casting_inference_denoiser()
@@ -14,9 +14,9 @@
# limitations under the License.
import gc
import sys
import unittest
import numpy as np
import pytest
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
@@ -55,7 +55,7 @@ if is_accelerate_available():
from accelerate.utils import release_memory
class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
class TestStableDiffusionLoRA(PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusionPipeline
scheduler_cls = DDIMScheduler
scheduler_kwargs = {
@@ -91,16 +91,6 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
def output_shape(self):
return (1, 64, 64, 3)
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
# Keeping this test here makes sense because it doesn't look any integration
# (value assertions on logits).
@slow
@@ -114,15 +104,8 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipe.load_lora_weights(lora_id, adapter_name="adapter-2")
pipe = pipe.to(torch_device)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in unet",
)
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"
# We will offload the first adapter in CPU and check if the offloading
# has been performed correctly
@@ -130,35 +113,35 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
for name, module in pipe.unet.named_modules():
if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
assert module.weight.device == torch.device("cpu")
elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
assert module.weight.device != torch.device("cpu")
for name, module in pipe.text_encoder.named_modules():
if "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
assert module.weight.device == torch.device("cpu")
elif "adapter-2" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
assert module.weight.device != torch.device("cpu")
pipe.set_lora_device(["adapter-1"], 0)
for n, m in pipe.unet.named_modules():
if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
assert m.weight.device != torch.device("cpu")
for n, m in pipe.text_encoder.named_modules():
if "adapter-1" in n and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
assert m.weight.device != torch.device("cpu")
pipe.set_lora_device(["adapter-1", "adapter-2"], torch_device)
for n, m in pipe.unet.named_modules():
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
assert m.weight.device != torch.device("cpu")
for n, m in pipe.text_encoder.named_modules():
if ("adapter-1" in n or "adapter-2" in n) and not isinstance(m, (nn.Dropout, nn.Identity)):
self.assertTrue(m.weight.device != torch.device("cpu"))
assert m.weight.device != torch.device("cpu")
@slow
@require_torch_accelerator
@@ -181,15 +164,9 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in unet",
)
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"
for name, param in pipe.unet.named_parameters():
if "lora_" in name:
@@ -225,17 +202,14 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipe.unet.add_adapter(config1, adapter_name="adapter-1")
pipe = pipe.to(torch_device)
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in unet",
)
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet"
# sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
self.assertNotEqual(modules_adapter_0, modules_adapter_1)
self.assertTrue(modules_adapter_0 - modules_adapter_1)
self.assertTrue(modules_adapter_1 - modules_adapter_0)
assert modules_adapter_0 != modules_adapter_1
assert modules_adapter_0 - modules_adapter_1
assert modules_adapter_1 - modules_adapter_0
# setting both separately works
pipe.set_lora_device(["adapter-0"], "cpu")
@@ -243,32 +217,30 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
for name, module in pipe.unet.named_modules():
if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
assert module.weight.device == torch.device("cpu")
elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device == torch.device("cpu"))
assert module.weight.device == torch.device("cpu")
# setting both at once also works
pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
for name, module in pipe.unet.named_modules():
if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
assert module.weight.device != torch.device("cpu")
elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
self.assertTrue(module.weight.device != torch.device("cpu"))
assert module.weight.device != torch.device("cpu")
@slow
@nightly
@require_torch_accelerator
@require_peft_backend
class LoraIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
class TestSDLoraIntegration:
@pytest.fixture(autouse=True)
def _gc_and_cache_cleanup(self, torch_device):
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
yield
gc.collect()
backend_empty_cache(torch_device)
@@ -280,10 +252,7 @@ class LoraIntegrationTests(unittest.TestCase):
pipe.load_lora_weights(lora_id)
pipe = pipe.to(torch_device)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
prompt = "a red sks dog"
@@ -312,10 +281,7 @@ class LoraIntegrationTests(unittest.TestCase):
pipe.load_lora_weights(lora_id)
pipe = pipe.to(torch_device)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder",
)
assert check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
prompt = "a red sks dog"
@@ -587,8 +553,8 @@ class LoraIntegrationTests(unittest.TestCase):
).images
unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()
self.assertFalse(np.allclose(initial_images, lora_images))
self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))
assert not np.allclose(initial_images, lora_images)
assert np.allclose(initial_images, unloaded_lora_images, atol=1e-3)
release_memory(pipe)
@@ -625,8 +591,8 @@ class LoraIntegrationTests(unittest.TestCase):
).images
unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()
self.assertFalse(np.allclose(initial_images, lora_images))
self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))
assert not np.allclose(initial_images, lora_images)
assert np.allclose(initial_images, unloaded_lora_images, atol=1e-3)
# make sure we can load a LoRA again after unloading and they don't have
# any undesired effects.
@@ -637,7 +603,7 @@ class LoraIntegrationTests(unittest.TestCase):
).images
lora_images_again = lora_images_again[0, -3:, -3:, -1].flatten()
self.assertTrue(np.allclose(lora_images, lora_images_again, atol=1e-3))
assert np.allclose(lora_images, lora_images_again, atol=1e-3)
release_memory(pipe)
def test_not_empty_state_dict(self):
@@ -651,7 +617,7 @@ class LoraIntegrationTests(unittest.TestCase):
lcm_lora = load_file(cached_file)
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertTrue(lcm_lora != {})
assert lcm_lora != {}
release_memory(pipe)
def test_load_unload_load_state_dict(self):
@@ -666,11 +632,11 @@ class LoraIntegrationTests(unittest.TestCase):
previous_state_dict = lcm_lora.copy()
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertDictEqual(lcm_lora, previous_state_dict)
assert lcm_lora == previous_state_dict
pipe.unload_lora_weights()
pipe.load_lora_weights(lcm_lora, adapter_name="lcm")
self.assertDictEqual(lcm_lora, previous_state_dict)
assert lcm_lora == previous_state_dict
release_memory(pipe)
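The setUp/tearDown pairs in the integration classes above collapse into a single autouse pytest fixture: everything before the yield runs before each test, everything after it runs afterwards, so the gc.collect()/backend_empty_cache() bracketing is preserved. A standalone sketch of the same pattern, assuming torch_device is itself available as a fixture and backend_empty_cache comes from the test utilities, as in the files above; class and fixture names here are illustrative:

import gc

import pytest


class TestExampleIntegration:
    @pytest.fixture(autouse=True)
    def _gc_and_cache_cleanup(self, torch_device):
        # runs before every test in this class
        gc.collect()
        backend_empty_cache(torch_device)
        yield
        # runs after every test in this class
        gc.collect()
        backend_empty_cache(torch_device)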
@@ -14,9 +14,9 @@
# limitations under the License.
import gc
import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -51,7 +51,7 @@ if is_accelerate_available():
@require_peft_backend
class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestSD3LoRA(PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -113,19 +113,19 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
lora_filename = "lora_peft_format.safetensors"
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
@unittest.skip("Not supported in SD3.")
@pytest.mark.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in SD3.")
@pytest.mark.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
pass
@unittest.skip("Not supported in SD3.")
@pytest.mark.skip("Not supported in SD3.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in SD3.")
@pytest.mark.skip("Not supported in SD3.")
def test_modify_padding_mode(self):
pass
@@ -138,17 +138,15 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
class SD3LoraIntegrationTests(unittest.TestCase):
class TestSD3LoraIntegration:
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
def setUp(self):
super().setUp()
@pytest.fixture(autouse=True)
def _gc_and_cache_cleanup(self, torch_device):
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
yield
gc.collect()
backend_empty_cache(torch_device)
@@ -17,9 +17,9 @@ import gc
import importlib
import sys
import time
import unittest
import numpy as np
import pytest
import torch
from packaging import version
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -59,7 +59,7 @@ if is_accelerate_available():
from accelerate.utils import release_memory
class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
class TestStableDiffusionXLLoRA(PeftLoraLoaderMixinTests):
has_two_text_encoders = True
pipeline_class = StableDiffusionXLPipeline
scheduler_cls = EulerDiscreteScheduler
@@ -104,21 +104,11 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
def output_shape(self):
return (1, 64, 64, 3)
def setUp(self):
super().setUp()
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
@is_flaky
def test_multiple_wrong_adapter_name_raises_error(self):
super().test_multiple_wrong_adapter_name_raises_error()
def test_simple_inference_with_text_denoiser_lora_unfused(self):
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
if torch.cuda.is_available():
expected_atol = 9e-2
expected_rtol = 9e-2
@@ -127,10 +117,10 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
expected_rtol = 1e-3
super().test_simple_inference_with_text_denoiser_lora_unfused(
expected_atol=expected_atol, expected_rtol=expected_rtol
pipe=pipe, expected_atol=expected_atol, expected_rtol=expected_rtol
)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
if torch.cuda.is_available():
expected_atol = 9e-2
expected_rtol = 9e-2
@@ -139,10 +129,10 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
expected_rtol = 1e-3
super().test_simple_inference_with_text_lora_denoiser_fused_multi(
expected_atol=expected_atol, expected_rtol=expected_rtol
pipe=pipe, expected_atol=expected_atol, expected_rtol=expected_rtol
)
def test_lora_scale_kwargs_match_fusion(self):
def test_lora_scale_kwargs_match_fusion(self, base_pipe_output):
if torch.cuda.is_available():
expected_atol = 9e-2
expected_rtol = 9e-2
@@ -150,21 +140,21 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
expected_atol = 1e-3
expected_rtol = 1e-3
super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol)
super().test_lora_scale_kwargs_match_fusion(
base_pipe_output=base_pipe_output, expected_atol=expected_atol, expected_rtol=expected_rtol
)
@slow
@nightly
@require_torch_accelerator
@require_peft_backend
class LoraSDXLIntegrationTests(unittest.TestCase):
def setUp(self):
super().setUp()
class TestLoraSDXLIntegration:
@pytest.fixture(autouse=True)
def _gc_and_cache_cleanup(self, torch_device):
gc.collect()
backend_empty_cache(torch_device)
def tearDown(self):
super().tearDown()
yield
gc.collect()
backend_empty_cache(torch_device)
@@ -383,7 +373,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
end_time = time.time()
elapsed_time_fusion = end_time - start_time
self.assertTrue(elapsed_time_fusion < elapsed_time_non_fusion)
assert elapsed_time_fusion < elapsed_time_non_fusion
release_memory(pipe)
@@ -439,14 +429,14 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
for key, value in text_encoder_1_sd.items():
key = remap_key(key, fused_te_state_dict)
self.assertTrue(torch.allclose(fused_te_state_dict[key], value))
assert torch.allclose(fused_te_state_dict[key], value)
for key, value in text_encoder_2_sd.items():
key = remap_key(key, fused_te_2_state_dict)
self.assertTrue(torch.allclose(fused_te_2_state_dict[key], value))
assert torch.allclose(fused_te_2_state_dict[key], value)
for key, value in unet_state_dict.items():
self.assertTrue(torch.allclose(unet_state_dict[key], value))
assert torch.allclose(unet_state_dict[key], value)
pipe.fuse_lora()
pipe.unload_lora_weights()
@@ -589,7 +579,7 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
pipe.load_lora_weights(lora_id, weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipe = pipe.to(torch_device)
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
assert check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet"
prompt = "toy_face of a hacker with a hoodie"
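Several overridden tests above now accept pipe or base_pipe_output as parameters. The diff for tests/lora/utils.py is suppressed further below, so the actual fixture definitions are not visible on this page; presumably PeftLoraLoaderMixinTests exposes them roughly like the following sketch (the fixture names match the diff, the bodies are assumptions built from helpers that do appear above, such as get_dummy_components and get_dummy_inputs):

class PeftLoraLoaderMixinTests:
    @pytest.fixture
    def pipe(self):
        # assumption: build the pipeline from the dummy components each subclass defines
        components, _, _ = self.get_dummy_components()
        return self.pipeline_class(**components).to(torch_device)

    @pytest.fixture
    def base_pipe_output(self, pipe):
        # assumption: a baseline run without any LoRA loaded, reused by comparison tests
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        return pipe(**inputs, generator=torch.manual_seed(0))[0]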
@@ -13,8 +13,8 @@
# limitations under the License.
import sys
import unittest
import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -39,7 +39,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestWanLoRA(PeftLoraLoaderMixinTests):
pipeline_class = WanPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -104,40 +104,40 @@ class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
@unittest.skip("Not supported in Wan.")
@pytest.mark.skip("Not supported in Wan.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Wan.")
@pytest.mark.skip("Not supported in Wan.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Wan.")
@pytest.mark.skip("Not supported in Wan.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan.")
def test_simple_inference_with_text_lora_save_load(self):
pass
@@ -14,10 +14,9 @@
import os
import sys
import tempfile
import unittest
import numpy as np
import pytest
import safetensors.torch
import torch
from PIL import Image
@@ -32,7 +31,6 @@ from ..testing_utils import (
require_peft_backend,
require_peft_version_greater,
skip_mps,
torch_device,
)
@@ -47,7 +45,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
@is_flaky(max_attempts=10, description="very flaky class")
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class TestWanVACELoRA(PeftLoraLoaderMixinTests):
pipeline_class = WanVACEPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
@@ -121,56 +119,51 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
return noise, input_ids, pipeline_inputs
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
def test_simple_inference_with_text_lora_denoiser_fused_multi(self, pipe):
super().test_simple_inference_with_text_lora_denoiser_fused_multi(pipe=pipe, expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
def test_simple_inference_with_text_denoiser_lora_unfused(self, pipe):
super().test_simple_inference_with_text_denoiser_lora_unfused(pipe=pipe, expected_atol=9e-3)
@unittest.skip("Not supported in Wan VACE.")
@pytest.mark.skip("Not supported in Wan VACE.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
@unittest.skip("Not supported in Wan VACE.")
@pytest.mark.skip("Not supported in Wan VACE.")
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
pass
@unittest.skip("Not supported in Wan VACE.")
@pytest.mark.skip("Not supported in Wan VACE.")
def test_modify_padding_mode(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_partial_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_and_scale(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_fused(self):
pass
@unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
@pytest.mark.skip("Text encoder LoRA is not supported in Wan VACE.")
def test_simple_inference_with_text_lora_save_load(self):
pass
def test_layerwise_casting_inference_denoiser(self):
super().test_layerwise_casting_inference_denoiser()
@require_peft_version_greater("0.13.2")
def test_lora_exclude_modules_wanvace(self):
def test_lora_exclude_modules_wanvace(self, base_pipe_output, tmpdirname, pipe):
exclude_module_name = "vace_blocks.0.proj_out"
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
_, text_lora_config, denoiser_lora_config = self.get_dummy_components()
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = self.get_base_pipe_output()
self.assertTrue(output_no_lora.shape == self.output_shape)
assert base_pipe_output.shape == self.output_shape
# only supported for `denoiser` now
denoiser_lora_config.target_modules = ["proj_out"]
@@ -180,36 +173,30 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
)
# The state dict shouldn't contain the modules to be excluded from LoRA.
state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default")
self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
assert not any(exclude_module_name in k for k in state_dict_from_model)
assert any("proj_out" in k for k in state_dict_from_model)
output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdir:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts)
pipe.unload_lora_weights()
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
pipe.unload_lora_weights()
# Check in the loaded state dict.
loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
self.assertTrue(not any(exclude_module_name in k for k in loaded_state_dict))
self.assertTrue(any("proj_out" in k for k in loaded_state_dict))
# Check in the loaded state dict.
loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
assert not any(exclude_module_name in k for k in loaded_state_dict)
assert any("proj_out" in k for k in loaded_state_dict)
# Check in the state dict obtained after loading LoRA.
pipe.load_lora_weights(tmpdir)
state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
# Check in the state dict obtained after loading LoRA.
pipe.load_lora_weights(tmpdirname)
state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
assert not any(exclude_module_name in k for k in state_dict_from_model)
assert any("proj_out" in k for k in state_dict_from_model)
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
"LoRA should change outputs.",
)
self.assertTrue(
np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
"Lora outputs should match.",
)
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
super().test_simple_inference_with_text_denoiser_lora_and_scale()
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
assert not np.allclose(base_pipe_output, output_lora_exclude_modules, atol=1e-3, rtol=1e-3), (
"LoRA should change outputs."
)
assert np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3), (
"Lora outputs should match."
)
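test_lora_exclude_modules_wanvace above also switches from a tempfile.TemporaryDirectory context manager to a tmpdirname fixture argument. That fixture is defined in the suppressed tests/lora/utils.py, so the following is only a sketch of what it presumably looks like, built on pytest's built-in tmp_path fixture:

import pytest


@pytest.fixture
def tmpdirname(tmp_path):
    # assumption: a thin wrapper that hands tests a plain string path to a per-test temp dir
    return str(tmp_path)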
tests/lora/utils.py (1658 lines changed): file diff suppressed because it is too large.
@@ -1,93 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import PIL
import pytest
from diffusers.modular_pipelines import (
Flux2AutoBlocks,
Flux2ModularPipeline,
)
from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2ModularPipeline
pipeline_blocks_class = Flux2AutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
params = frozenset(["prompt", "height", "width", "guidance_scale"])
batch_params = frozenset(["prompt"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 4.0,
"height": 32,
"width": 32,
"output_type": "pt",
}
return inputs
def test_float16_inference(self):
super().test_float16_inference(9e-2)
class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
pipeline_class = Flux2ModularPipeline
pipeline_blocks_class = Flux2AutoBlocks
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
batch_params = frozenset(["prompt", "image"])
def get_dummy_inputs(self, seed=0):
generator = self.get_generator(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
"text_encoder_out_layers": (1,),
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 4.0,
"height": 32,
"width": 32,
"output_type": "pt",
}
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
inputs["image"] = init_image
return inputs
def test_float16_inference(self):
super().test_float16_inference(9e-2)
@pytest.mark.skip(reason="batched inference is currently not supported")
def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
return
@@ -165,6 +165,7 @@ class ModularPipelineTesterMixin:
expected_max_diff=1e-4,
):
pipe = self.get_pipeline().to(torch_device)
inputs = self.get_dummy_inputs()
# Reset generator in case it is has been used in self.get_dummy_inputs