Compare commits

...

28 Commits

Author SHA1 Message Date
Aryan
a1da7752e5 Merge branch 'main' into feature/guiders 2025-04-05 10:55:02 +02:00
Aryan
b30cf5d452 spatio temporal guidance 2025-04-05 10:39:46 +02:00
Edna
41afb6690c Add Wan with STG as a community pipeline (#11184)
* Add stg wan to community pipelines

* remove debug prints

* remove unused comment

* Update doc

* Add credit + fix typo

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-05 04:00:40 +02:00
Aryan
357f4f056b update 2025-04-04 22:44:02 +02:00
Tolga Cangöz
13e48492f0 [LTX0.9.5] Refactor LTXConditionPipeline for text-only conditioning (#11174)
* Refactor `LTXConditionPipeline` to add text-only conditioning

* style

* up

* Refactor `LTXConditionPipeline` to streamline condition handling and improve clarity

* Improve condition checks

* Simplify latents handling based on conditioning type

* Refactor rope_interpolation_scale preparation for clarity and efficiency

* Update LTXConditionPipeline docstring to clarify supported input types

* Add LTX Video 0.9.5 model to documentation

* Clarify documentation to indicate support for text-only conditioning without passing `conditions`

* refactor: comment out unused parameters in LTXConditionPipeline

* fix: restore previously commented parameters in LTXConditionPipeline

* fix: remove unused parameters from LTXConditionPipeline

* refactor: remove unnecessary lines in LTXConditionPipeline
2025-04-04 16:43:15 +02:00
Suprhimp
94f2c48d58 [feat]Add strength in flux_fill pipeline (denoising strength for fluxfill) (#10603)
* [feat]add strength in flux_fill pipeline

* Update src/diffusers/pipelines/flux/pipeline_flux_fill.py

* Update src/diffusers/pipelines/flux/pipeline_flux_fill.py

* Update src/diffusers/pipelines/flux/pipeline_flux_fill.py

* [refactor] refactor after review

* [fix] change comment

* Apply style fixes

* empty

* fix

* update prepare_latents from flux.img2img pipeline

* style

* Update src/diffusers/pipelines/flux/pipeline_flux_fill.py

---------
2025-04-04 11:23:30 -03:00
Dhruv Nair
aabf8ce20b Fix Single File loading for LTX VAE (#11200)
update
2025-04-04 18:02:39 +05:30
Kenneth Gerald Hamilton
f10775b1b5 Fixed requests.get function call by adding timeout parameter. (#11156)
* Fixed requests.get function call by adding timeout parameter.

* declare DIFFUSERS_REQUEST_TIMEOUT in constants and import when needed

* remove unneeded os import

* Apply style fixes

---------

Co-authored-by: Sai-Suraj-27 <sai.suraj.27.729@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-04-04 07:23:14 +01:00
Aryan
53b6b9fcb6 perturbed attention guidance 2025-04-04 04:16:46 +02:00
Aryan
46643564a3 refactor 2025-04-04 01:41:34 +02:00
célina
6edb774b5e Update Style Bot workflow (#11202)
update style bot workflow
2025-04-03 19:31:49 +02:00
Basile Lewandowski
480510ada9 Change KolorsPipeline LoRA Loader to StableDiffusion (#11198)
Change LoRA Loader to StableDiffusion

Replace the SDXL LoRA Loader Mixin inheritance with the StableDiffusion one
2025-04-03 11:21:11 -03:00
Aryan
77324c40c4 adaptive projected guidance 2025-04-03 05:01:54 +02:00
Aryan
05d74ef3e7 cfg zero star 2025-04-03 04:21:24 +02:00
Aryan
9997c223a8 more slg improvements 2025-04-03 03:50:30 +02:00
Aryan
d91d10737a update slg docstring 2025-04-03 03:43:10 +02:00
Aryan
5ac7f360b0 skip layer guidance 2025-04-03 03:26:55 +02:00
Aryan
594e8d663f classifier-free guidance 2025-04-03 00:13:15 +02:00
Aryan
c76e1cc17e update 2025-04-02 21:52:33 +02:00
Aryan
315e357a18 Merge branch 'main' into integrations/first-block-cache-2 2025-04-02 01:21:22 +02:00
Aryan
1f33ca276d support flux, ltx i2v, ltx condition 2025-04-02 01:21:09 +02:00
Aryan
41b0c473d2 fix controlnet flux 2025-04-02 01:20:53 +02:00
Aryan
0e232ac8c0 fix hs residual bug for single return outputs; support ltx 2025-04-02 00:38:11 +02:00
Aryan
2557238b4d cache context for different batches of data 2025-04-01 19:40:23 +02:00
Aryan
d71fe55895 update 2025-04-01 17:06:45 +02:00
Aryan
7ab424a15a remove debug logs 2025-04-01 01:39:00 +02:00
Aryan
dd69b41834 modify flux single blocks to make compatible with cache techniques (without too much model-specific intrusion code) 2025-04-01 01:28:09 +02:00
Aryan
406b1062f8 update 2025-03-31 04:27:35 +02:00
52 changed files with 3243 additions and 331 deletions

View File

@@ -13,39 +13,5 @@ jobs:
uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main
with:
python_quality_dependencies: "[quality]"
pre_commit_script_name: "Download and Compare files from the main branch"
pre_commit_script: |
echo "Downloading the files from the main branch"
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
curl -o main_setup.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/setup.py
curl -o main_check_doc_toc.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/utils/check_doc_toc.py
echo "Compare the files and raise error if needed"
diff_failed=0
if ! diff -q main_Makefile Makefile; then
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
diff_failed=1
fi
if ! diff -q main_setup.py setup.py; then
echo "Error: The setup.py has changed. Please ensure it matches the main branch."
diff_failed=1
fi
if ! diff -q main_check_doc_toc.py utils/check_doc_toc.py; then
echo "Error: The utils/check_doc_toc.py has changed. Please ensure it matches the main branch."
diff_failed=1
fi
if [ $diff_failed -eq 1 ]; then
echo "❌ Error happened as we detected changes in the files that should not be changed ❌"
exit 1
fi
echo "No changes in the files. Proceeding..."
rm -rf main_Makefile main_setup.py main_check_doc_toc.py
style_command: "make style && make quality"
secrets:
bot_token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -11,33 +11,6 @@ specific language governing permissions and limitations under the License. -->
# Caching methods
## Pyramid Attention Broadcast
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
# broadcast is active, leading to slower inference speeds. However, large intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(100, 800),
current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```
## Faster Cache
[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
@@ -65,18 +38,68 @@ config = FasterCacheConfig(
pipe.transformer.enable_cache(config)
```
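The full argument list for the example above is cut off by this diff. As a rough sketch, enabling FasterCache typically looks like the following; the `FasterCacheConfig` fields and values shown here are assumptions modeled on the Pyramid Attention Broadcast example, not taken from this diff.
```python
import torch
from diffusers import CogVideoXPipeline, FasterCacheConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Assumed configuration fields, mirroring the PAB example above; consult the
# FasterCacheConfig docstring for the authoritative argument names and defaults.
config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```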
## First Block Cache
[First Block Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) is a method that builds upon the ideas of [TeaCache](https://huggingface.co/papers/2411.19108) to speed up inference in diffusion transformers. The generation quality is superior with greatly reduced inference time. This method always computes the output of the first transformer block and computes the differences between past and current outputs of the first transformer block. If the difference is smaller than a predefined threshold, the computation of remaining transformer blocks is skipped, and otherwise the computation is performed as usual.
```python
import torch
from diffusers import CogVideoXPipeline, FirstBlockCacheConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the threshold may lead to faster inference speeds, but may also lead to poorer quality of generated videos.
# Smaller values between 0.02-2.0 are recommended based on the model being used. The default value is 0.05.
config = FirstBlockCacheConfig(threshold=0.07)
pipe.transformer.enable_cache(config)
```
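To make the caching criterion above concrete, here is a minimal illustrative sketch of the decision rule; `should_skip_remaining_blocks` is a hypothetical helper, not part of the diffusers API.
```python
import torch

def should_skip_remaining_blocks(
    first_block_out: torch.Tensor, prev_first_block_out: torch.Tensor, threshold: float = 0.05
) -> bool:
    # Relative change of the first transformer block output between two inference steps.
    diff = (first_block_out - prev_first_block_out).abs().mean()
    rel_diff = diff / prev_first_block_out.abs().mean().clamp_min(1e-6)
    # Below the threshold, the cached residual is reused instead of running the remaining blocks.
    return bool(rel_diff < threshold)
```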
## Pyramid Attention Broadcast
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
```python
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
# broadcast is active, leading to slower inference speeds. However, large intervals can lead to
# poorer quality of generated videos.
config = PyramidAttentionBroadcastConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(100, 800),
current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```
### CacheMixin
[[autodoc]] CacheMixin
### PyramidAttentionBroadcastConfig
[[autodoc]] PyramidAttentionBroadcastConfig
[[autodoc]] apply_pyramid_attention_broadcast
### FasterCacheConfig
[[autodoc]] FasterCacheConfig
[[autodoc]] apply_faster_cache
### FirstBlockCacheConfig
[[autodoc]] FirstBlockCacheConfig
[[autodoc]] apply_first_block_cache
### PyramidAttentionBroadcastConfig
[[autodoc]] PyramidAttentionBroadcastConfig
[[autodoc]] apply_pyramid_attention_broadcast

View File

@@ -32,6 +32,7 @@ Available models:
|:-------------:|:-----------------:|
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
| [`LTX Video 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
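As a hedged illustration of that note, a pipeline for the 0.9.5 checkpoint could be loaded roughly as follows; the repository id is an assumption, so check the model card for the exact name.
```python
import torch
from diffusers import LTXConditionPipeline

# The transformer uses the recommended torch.bfloat16; the VAE and text encoder follow the pipeline
# dtype here, but torch.float32 or torch.float16 would also work for those components.
pipe = LTXConditionPipeline.from_pretrained(
    "Lightricks/LTX-Video-0.9.5",  # assumed repo id for the 0.9.5 checkpoint
    torch_dtype=torch.bfloat16,
)
pipe.to("cuda")
```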

View File

@@ -10,7 +10,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Example | Description | Code Example | Colab | Author |
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
|Spatiotemporal Skip Guidance (STG)|[Spatiotemporal Skip Guidance for Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664) (CVPR 2025) enhances video diffusion models by generating a weaker model through layer skipping and using it as guidance, improving fidelity in models like HunyuanVideo, LTXVideo, and Mochi.|[Spatiotemporal Skip Guidance](#spatiotemporal-skip-guidance)|-|[Junha Hyung](https://junhahyung.github.io/), [Kinam Kim](https://kinam0252.github.io/)|
|Spatiotemporal Skip Guidance (STG)|[Spatiotemporal Skip Guidance for Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664) (CVPR 2025) enhances video diffusion models by generating a weaker model through layer skipping and using it as guidance, improving fidelity in models like HunyuanVideo, LTXVideo, and Mochi.|[Spatiotemporal Skip Guidance](#spatiotemporal-skip-guidance)|-|[Junha Hyung](https://junhahyung.github.io/), [Kinam Kim](https://kinam0252.github.io/), and [Ednaordinary](https://github.com/Ednaordinary)|
|Adaptive Mask Inpainting|Adaptive Mask Inpainting algorithm from [Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models](https://github.com/snuvclab/coma) (ECCV '24, Oral) provides a way to insert human inside the scene image without altering the background, by inpainting with adapting mask.|[Adaptive Mask Inpainting](#adaptive-mask-inpainting)|-|[Hyeonwoo Kim](https://sshowbiz.xyz),[Sookwan Han](https://jellyheadandrew.github.io)|
|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/flux_with_cfg.ipynb)|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
|Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/exx8/differential-diffusion) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
@@ -124,7 +124,6 @@ pipe = pipe.to("cuda")
#--------Option--------#
prompt = "A close-up of a beautiful woman's face with colored powder exploding around her, creating an abstract splash of vibrant hues, realistic style."
stg_applied_layers_idx = [34]
stg_mode = "STG"
stg_scale = 1.0 # 0.0 for CFG
#----------------------#

View File

@@ -0,0 +1,661 @@
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
import types
from typing import Any, Callable, Dict, List, Optional, Union
import ftfy
import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.loaders import WanLoraLoaderMixin
from diffusers.models import AutoencoderKLWan, WanTransformer3DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers.utils import export_to_video
>>> from diffusers import AutoencoderKLWan
>>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
>>> from examples.community.pipeline_stg_wan import WanSTGPipeline
>>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
>>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
>>> pipe = WanSTGPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
>>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
>>> pipe.to("cuda")
>>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
>>> # Configure STG mode options
>>> stg_applied_layers_idx = [8] # Layer indices from 0 to 39 for 14b or 0 to 29 for 1.3b
>>> stg_scale = 1.0 # Set 0.0 for CFG
>>> output = pipe(
... prompt=prompt,
... negative_prompt=negative_prompt,
... height=720,
... width=1280,
... num_frames=81,
... guidance_scale=5.0,
... stg_applied_layers_idx=stg_applied_layers_idx,
... stg_scale=stg_scale,
... ).frames[0]
>>> export_to_video(output, "output.mp4", fps=16)
```
"""
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def prompt_clean(text):
text = whitespace_clean(basic_clean(text))
return text
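# Perturbed forward used for spatio-temporal guidance: blocks patched with this act as an identity
# mapping (attention and feed-forward are skipped, only the residual stream passes through).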
def forward_with_stg(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
return hidden_states
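# Standard Wan transformer block forward, restored on every block before the conditional and
# unconditional passes so that only the perturbed pass skips the selected layers.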
def forward_without_stg(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = hidden_states + attn_output
# 3. Feed-forward
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
ff_output = self.ffn(norm_hidden_states)
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
return hidden_states
class WanSTGPipeline(DiffusionPipeline, WanLoraLoaderMixin):
r"""
Pipeline for text-to-video generation using Wan.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
tokenizer ([`T5Tokenizer`]):
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
transformer ([`WanTransformer3DModel`]):
Conditional Transformer to denoise the input latents.
scheduler ([`UniPCMultistepScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt = [prompt_clean(u) for u in prompt]
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
seq_lens = mask.gt(0).sum(dim=1).long()
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
prompt_embeds = torch.stack(
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
return prompt_embeds
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
max_sequence_length: int = 226,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, negative_prompt_embeds
def check_inputs(
self,
prompt,
negative_prompt,
height,
width,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif negative_prompt is not None and (
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int = 16,
height: int = 480,
width: int = 832,
num_frames: int = 81,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype)
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
shape = (
batch_size,
num_channels_latents,
num_latent_frames,
int(height) // self.vae_scale_factor_spatial,
int(width) // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def do_spatio_temporal_guidance(self):
return self._stg_scale > 0.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@property
def attention_kwargs(self):
return self._attention_kwargs
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Union[str, List[str]] = None,
height: int = 480,
width: int = 832,
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "np",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
stg_applied_layers_idx: Optional[List[int]] = [3, 8, 16],
stg_scale: Optional[float] = 0.0,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
height (`int`, defaults to `480`):
The height in pixels of the generated image.
width (`int`, defaults to `832`):
The width in pixels of the generated image.
num_frames (`int`, defaults to `81`):
The number of frames in the generated video.
num_inference_steps (`int`, defaults to `50`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `5.0`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
output_type (`str`, *optional*, defaults to `"np"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during inference with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
stg_applied_layers_idx (`List[int]`, *optional*, defaults to `[3, 8, 16]`):
The indices of the transformer blocks that are replaced by an identity forward (skipped) when
computing the perturbed prediction for spatio-temporal guidance.
stg_scale (`float`, *optional*, defaults to `0.0`):
The scale of the spatio-temporal guidance term. Set to `0.0` to disable STG and fall back to plain
classifier-free guidance.
Examples:
Returns:
[`~WanPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
the first element is a list with the generated frames.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
negative_prompt,
height,
width,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._stg_scale = stg_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
device = self._execution_device
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
)
transformer_dtype = self.transformer.dtype
prompt_embeds = prompt_embeds.to(transformer_dtype)
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
torch.float32,
device,
generator,
latents,
)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
if self.do_spatio_temporal_guidance:
for idx, block in enumerate(self.transformer.blocks):
block.forward = types.MethodType(forward_without_stg, block)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_spatio_temporal_guidance:
for idx, block in enumerate(self.transformer.blocks):
if idx in stg_applied_layers_idx:
block.forward = types.MethodType(forward_with_stg, block)
noise_perturb = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
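# Combine classifier-free guidance with the spatio-temporal guidance term
# (the difference between the unperturbed and layer-skipped predictions).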
noise_pred = (
noise_uncond
+ guidance_scale * (noise_pred - noise_uncond)
+ self._stg_scale * (noise_pred - noise_perturb)
)
else:
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
latents = latents / latents_std + latents_mean
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
video = latents
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return WanPipelineOutput(frames=video)

View File

@@ -49,6 +49,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2P
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -418,7 +419,7 @@ def convert_to_np(image, resolution):
def download_image(url):
image = PIL.Image.open(requests.get(url, stream=True).raw)
image = PIL.Image.open(requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image

View File

@@ -54,6 +54,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2P
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, cast_training_params
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, deprecate, is_wandb_available
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -475,7 +476,7 @@ def convert_to_np(image, resolution):
def download_image(url):
image = PIL.Image.open(requests.get(url, stream=True).raw)
image = PIL.Image.open(requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image

View File

@@ -59,6 +59,7 @@ from diffusers.schedulers import (
UnCLIPScheduler,
)
from diffusers.utils import is_accelerate_available, logging
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
if is_accelerate_available():
@@ -1435,7 +1436,7 @@ def download_from_original_stable_diffusion_ckpt(
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)
original_config_file = BytesIO(requests.get(config_url, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
else:
with open(original_config_file, "r") as f:
original_config_file = f.read()

View File

@@ -11,6 +11,7 @@ from diffusion import sampling
from torch import nn
from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
MODELS_MAP = {
@@ -74,7 +75,7 @@ class DiffusionUncond(nn.Module):
def download(model_name):
url = MODELS_MAP[model_name]["url"]
r = requests.get(url, stream=True)
r = requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT)
local_filename = f"./{model_name}.ckpt"
with open(local_filename, "wb") as fp:

View File

@@ -13,6 +13,7 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
renew_vae_attention_paths,
renew_vae_resnet_paths,
)
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
def custom_convert_ldm_vae_checkpoint(checkpoint, config):
@@ -122,7 +123,8 @@ def vae_pt_to_vae_diffuser(
):
# Only support V1
r = requests.get(
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml",
timeout=DIFFUSERS_REQUEST_TIMEOUT,
)
io_obj = io.BytesIO(r.content)

View File

@@ -33,6 +33,7 @@ from .utils import (
_import_structure = {
"configuration_utils": ["ConfigMixin"],
"guiders": [],
"hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
@@ -129,12 +130,25 @@ except OptionalDependencyNotAvailable:
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
_import_structure["guiders"].extend(
[
"AdaptiveProjectedGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"PerturbedAttentionGuidance",
"SkipLayerGuidance",
]
)
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
"FirstBlockCacheConfig",
"HookRegistry",
"LayerSkipConfig",
"PyramidAttentionBroadcastConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_pyramid_attention_broadcast",
]
)
@@ -708,11 +722,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .guiders import (
AdaptiveProjectedGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
PerturbedAttentionGuidance,
SkipLayerGuidance,
)
from .hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
HookRegistry,
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_pyramid_attention_broadcast,
)
from .models import (

View File

@@ -0,0 +1,24 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_torch_available
if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance

View File

@@ -0,0 +1,145 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from .guider_utils import GuidanceMixin, rescale_noise_cfg
class AdaptiveProjectedGuidance(GuidanceMixin):
"""
Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
The norm threshold applied to the guidance update: updates with a larger norm are rescaled down to
this value before being applied.
eta (`float`, defaults to `1.0`):
The weight of the parallel component of the guidance update; the orthogonal component is always
applied with weight `1.0`.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
adaptive_projected_guidance_momentum: Optional[float] = None,
adaptive_projected_guidance_rescale: float = 15.0,
eta: float = 1.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, *args):
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
return super().prepare_inputs(*args)
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled():
pred = pred_cond
else:
pred = normalized_guidance(
pred_cond,
pred_uncond,
self.guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + (guidance_scale - 1) * normalized_update
return pred
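# Illustrative standalone sketch (not part of the module): `forward` can be called directly on dummy
# predictions. Momentum is left disabled so that prepare_inputs()/step tracking from GuidanceMixin is
# not needed; in a pipeline the guider machinery is assumed to handle that.
if __name__ == "__main__":
    guider = AdaptiveProjectedGuidance(guidance_scale=7.5)
    pred_cond = torch.randn(1, 4, 32, 32)
    pred_uncond = torch.randn(1, 4, 32, 32)
    pred = guider.forward(pred_cond, pred_uncond)  # projected, norm-thresholded guidance update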

View File

@@ -0,0 +1,98 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from .guider_utils import GuidanceMixin, rescale_noise_cfg
class ClassifierFreeGuidance(GuidanceMixin):
"""
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
inference. This allows the model to trade off between generation quality and sample diversity.
The original paper proposes scaling and shifting the conditional distribution based on the difference between
conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
The intuition behind the original formulation can be thought of as moving the conditional distribution estimates
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
the unconditional predictions (usually negative features like "bad quality, bad anatomy, watermarks", etc.)
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False
):
super().__init__()
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
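# Illustrative standalone sketch (not part of the module): the two formulations described in the
# docstring differ only in which prediction the scaled shift is added to.
if __name__ == "__main__":
    pred_cond = torch.randn(1, 4, 8, 8)
    pred_uncond = torch.randn(1, 4, 8, 8)
    native = ClassifierFreeGuidance(guidance_scale=7.5)
    original = ClassifierFreeGuidance(guidance_scale=7.5, use_original_formulation=True)
    out_native = native.forward(pred_cond, pred_uncond)      # pred_uncond + 7.5 * (pred_cond - pred_uncond)
    out_original = original.forward(pred_cond, pred_uncond)  # pred_cond + 7.5 * (pred_cond - pred_uncond)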

View File

@@ -0,0 +1,110 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
from .guider_utils import GuidanceMixin, rescale_noise_cfg
class ClassifierFreeZeroStarGuidance(GuidanceMixin):
"""
Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
quality of generated images.
The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
zero_init_steps (`int`, defaults to `1`):
The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
zero_init_steps: int = 1,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.guidance_scale = guidance_scale
self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if self._step < self.zero_init_steps:
pred = torch.zeros_like(pred_cond)
elif not self._is_cfg_enabled():
pred = pred_cond
else:
pred_cond_flat = pred_cond.flatten(1)
pred_uncond_flat = pred_uncond.flatten(1)
alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
pred_uncond = pred_uncond * alpha
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
cond = cond.float()
uncond = uncond.float()
dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
scale = dot_product / squared_norm
return scale.type_as(cond)
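# Illustrative sketch with assumed toy shapes: `cfg_zero_star_scale` computes the least-squares
# coefficient that projects the conditional prediction onto the unconditional one. The rescaled
# unconditional prediction is then used as the base point for the usual CFG shift.
import torch

cond = torch.randn(2, 16)    # flattened conditional predictions for a batch of 2
uncond = torch.randn(2, 16)  # flattened unconditional predictions
guidance_scale = 7.5

alpha = cfg_zero_star_scale(cond, uncond)  # shape (2, 1), one coefficient per sample
rescaled_uncond = uncond * alpha
guided = rescaled_uncond + guidance_scale * (cond - rescaled_uncond)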

View File

@@ -0,0 +1,213 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from ..utils import deprecate, get_logger
if TYPE_CHECKING:
from ..models.attention_processor import AttentionProcessor
logger = get_logger(__name__) # pylint: disable=invalid-name
class GuidanceMixin:
r"""Base mixin class providing the skeleton for implementing guidance techniques."""
_input_predictions = None
def __init__(self):
self._step: int = None
self._num_inference_steps: int = None
self._timestep: torch.LongTensor = None
self._preds: Dict[str, torch.Tensor] = {}
self._num_outputs_prepared: int = 0
if self._input_predictions is None or not isinstance(self._input_predictions, list):
raise ValueError(
"`_input_predictions` must be a list of required prediction names for the guidance technique."
)
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
self._step = step
self._num_inference_steps = num_inference_steps
self._timestep = timestep
self._preds = {}
self._num_outputs_prepared = 0
def prepare_models(self, denoiser: torch.nn.Module) -> None:
pass
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
num_conditions = self.num_conditions
list_of_inputs = []
for arg in args:
if isinstance(arg, torch.Tensor):
list_of_inputs.append([arg] * num_conditions)
elif isinstance(arg, (tuple, list)):
if len(arg) != 2:
raise ValueError(
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
f"with the first element being the conditional input and the second element being the unconditional input or None."
)
if arg[1] is None:
# Only conditioning inputs for all batches
list_of_inputs.append([arg[0]] * num_conditions)
else:
# Alternating conditional and unconditional inputs as batches
inputs = [arg[i % 2] for i in range(num_conditions)]
list_of_inputs.append(inputs)
else:
raise ValueError(
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
)
return tuple(list_of_inputs)
def prepare_outputs(self, pred: torch.Tensor) -> None:
self._num_outputs_prepared += 1
if self._num_outputs_prepared > self.num_conditions:
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
self._preds[key] = pred
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
pass
def __call__(self, **kwargs) -> Any:
if len(kwargs) != self.num_conditions:
raise ValueError(
f"Expected {self.num_conditions} arguments, but got {len(kwargs)}. Please provide the correct number of arguments."
)
return self.forward(**kwargs)
def forward(self, *args, **kwargs) -> Any:
raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.")
@property
def num_conditions(self) -> int:
raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.")
@property
def outputs(self) -> Dict[str, torch.Tensor]:
return self._preds
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
def _replace_attention_processors(
module: torch.nn.Module,
pag_applied_layers: Optional[Union[str, List[str]]] = None,
skip_context_attention: bool = False,
processors: Optional[List[Tuple[torch.nn.Module, "AttentionProcessor"]]] = None,
metadata_name: Optional[str] = None,
) -> Optional[List[Tuple[torch.nn.Module, "AttentionProcessor"]]]:
if processors is not None and metadata_name is not None:
raise ValueError("Cannot pass both `processors` and `metadata_name` at the same time.")
if metadata_name is not None:
if isinstance(pag_applied_layers, str):
pag_applied_layers = [pag_applied_layers]
return _replace_layers_with_guidance_processors(
module, pag_applied_layers, skip_context_attention, metadata_name
)
if processors is not None:
_replace_layers_with_existing_processors(processors)
def _replace_layers_with_guidance_processors(
module: torch.nn.Module,
pag_applied_layers: List[str],
skip_context_attention: bool,
metadata_name: str,
) -> List[Tuple[torch.nn.Module, "AttentionProcessor"]]:
from ..hooks._common import _ATTENTION_CLASSES
from ..hooks._helpers import GuidanceMetadataRegistry
processors = []
for name, submodule in module.named_modules():
if (
(not isinstance(submodule, _ATTENTION_CLASSES))
or (getattr(submodule, "processor", None) is None)
or not (
any(
re.search(pag_layer, name) is not None and not _is_fake_integral_match(pag_layer, name)
for pag_layer in pag_applied_layers
)
)
):
continue
old_attention_processor = submodule.processor
metadata = GuidanceMetadataRegistry.get(old_attention_processor.__class__)
new_attention_processor_cls = getattr(metadata, metadata_name)
new_attention_processor = new_attention_processor_cls()
# !!! dunder methods cannot be replaced on instances !!!
# if "skip_context_attention" in inspect.signature(new_attention_processor.__call__).parameters:
# new_attention_processor.__call__ = partial(
# new_attention_processor.__call__, skip_context_attention=skip_context_attention
# )
submodule.processor = new_attention_processor
processors.append((submodule, old_attention_processor))
return processors
def _replace_layers_with_existing_processors(processors: List[Tuple[torch.nn.Module, "AttentionProcessor"]]) -> None:
for module, proc in processors:
module.processor = proc
def _is_fake_integral_match(layer_id, name):
layer_id = layer_id.split(".")[-1]
name = name.split(".")[-1]
return layer_id.isnumeric() and name.isnumeric() and layer_id == name
def _raise_guidance_deprecation_warning(
*,
guidance_scale: Optional[float] = None,
guidance_rescale: Optional[float] = None,
) -> None:
if guidance_scale is not None:
msg = "The `guidance_scale` argument is deprecated and will be removed in version 1.0.0. Please pass a `GuidanceMixin` object for the `guidance` argument instead."
deprecate("guidance_scale", "1.0.0", msg, standard_warn=False)
if guidance_rescale is not None:
msg = "The `guidance_rescale` argument is deprecated and will be removed in version 1.0.0. Please pass a `GuidanceMixin` object for the `guidance` argument instead."
deprecate("guidance_rescale", "1.0.0", msg, standard_warn=False)

View File

@@ -0,0 +1,180 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import torch
from .guider_utils import GuidanceMixin, _replace_attention_processors, rescale_noise_cfg
class PerturbedAttentionGuidance(GuidanceMixin):
"""
Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
Args:
pag_applied_layers (`str` or `List[str]`):
The name of the attention layers where Perturbed Attention Guidance is applied. This can be a single layer
name or a list of layer names. The names should either be FQNs (fully qualified names) to each attention
layer or a regex pattern that matches the FQNs of the attention layers. For example, if you want to apply
PAG to transformer blocks 10 and 20, you can set this to `["transformer_blocks.10",
"transformer_blocks.20"]`, or `"transformer_blocks.(10|20)"`.
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
pag_scale (`float`, defaults to `3.0`):
The scale parameter for perturbed attention guidance.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond", "pred_perturbed"]
def __init__(
self,
pag_applied_layers: Union[str, List[str]],
guidance_scale: float = 7.5,
pag_scale: float = 3.0,
skip_context_attention: bool = False,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.pag_applied_layers = pag_applied_layers
self.guidance_scale = guidance_scale
self.pag_scale = pag_scale
self.skip_context_attention = skip_context_attention
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
self._is_pag_batch = False
self._original_processors = None
self._denoiser = None
def prepare_models(self, denoiser: torch.nn.Module):
self._denoiser = denoiser
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
num_conditions = self.num_conditions
list_of_inputs = []
for arg in args:
if isinstance(arg, torch.Tensor):
list_of_inputs.append([arg] * num_conditions)
elif isinstance(arg, (tuple, list)):
if len(arg) != 2:
raise ValueError(
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
f"with the first element being the conditional input and the second element being the unconditional input or None."
)
if arg[1] is None:
# Only conditioning inputs for all batches
list_of_inputs.append([arg[0]] * num_conditions)
else:
list_of_inputs.append([arg[0], arg[1], arg[0]])
else:
raise ValueError(
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
)
return tuple(list_of_inputs)
def prepare_outputs(self, pred: torch.Tensor) -> None:
self._num_outputs_prepared += 1
if self._num_outputs_prepared > self.num_conditions:
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
if not self._is_cfg_enabled() and self._is_pag_enabled():
# If we're predicting pred_cond and pred_perturbed only, we need to set the key to pred_perturbed
# to avoid writing into pred_uncond which is not used
if self._num_outputs_prepared == 2:
key = "pred_perturbed"
self._preds[key] = pred
# Restore the original attention processors if previously replaced
if self._is_pag_batch:
_replace_attention_processors(self._denoiser, processors=self._original_processors)
self._is_pag_batch = False
self._original_processors = None
# Prepare denoiser for perturbed attention prediction if needed
if self._is_pag_enabled():
should_register_pag = (self._is_cfg_enabled() and self._num_outputs_prepared == 2) or (
not self._is_cfg_enabled() and self._num_outputs_prepared == 1
)
if should_register_pag:
self._is_pag_batch = True
self._original_processors = _replace_attention_processors(
self._denoiser,
self.pag_applied_layers,
skip_context_attention=self.skip_context_attention,
metadata_name="perturbed_attention_guidance_processor_cls",
)
def cleanup_models(self, denoiser: torch.nn.Module):
self._denoiser = None
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_perturbed: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_pag_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_perturbed
pred = pred_cond + self.pag_scale * shift
elif not self._is_pag_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_perturbed = pred_cond - pred_perturbed
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.pag_scale * shift_perturbed
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_pag_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
def _is_pag_enabled(self) -> bool:
is_zero = math.isclose(self.pag_scale, 0.0)
return not is_zero
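# Illustrative sketch with assumed toy tensors of the combined update computed in `forward`
# above when both CFG and PAG are enabled: both shifts are applied on top of the same base
# prediction.
import torch

pred_cond = torch.randn(1, 4, 64, 64)       # normal conditional prediction
pred_uncond = torch.randn(1, 4, 64, 64)     # unconditional prediction
pred_perturbed = torch.randn(1, 4, 64, 64)  # conditional prediction with perturbed attention
guidance_scale, pag_scale = 7.5, 3.0

shift = pred_cond - pred_uncond
shift_perturbed = pred_cond - pred_perturbed
# Diffusers-native base point; with `use_original_formulation=True` the base would be `pred_cond`.
pred = pred_uncond + guidance_scale * shift + pag_scale * shift_perturbed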

View File

@@ -0,0 +1,235 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import torch
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import GuidanceMixin, rescale_noise_cfg
class SkipLayerGuidance(GuidanceMixin):
"""
Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 Spatio-Temporal Guidance (STG):
https://huggingface.co/papers/2411.18664
SLG was introduced by StabilityAI for improving structure and anatomy coherence in generated images. It works by
skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
batch of data, apart from the conditional and unconditional batches already used in CFG
([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
based on the difference between conditional without skipping and conditional with skipping predictions.
The intuition behind SLG is that it moves the CFG-predicted distribution estimates further away from
worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
version of the model for the conditional prediction).
STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
generation quality in video diffusion models.
Additional reading:
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
skip_layer_guidance_scale (`float`, defaults to `2.8`):
The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
values, but it may also lead to overexposure and saturation.
skip_layer_guidance_start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which skip layer guidance starts.
skip_layer_guidance_stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which skip layer guidance stops.
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
3.5 Medium.
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
"""
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
def __init__(
self,
guidance_scale: float = 7.5,
skip_layer_guidance_scale: float = 2.8,
skip_layer_guidance_start: float = 0.01,
skip_layer_guidance_stop: float = 0.2,
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
skip_layer_config: Optional[Union[LayerSkipConfig, List[LayerSkipConfig]]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
):
super().__init__()
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = skip_layer_guidance_scale
self.skip_layer_guidance_start = skip_layer_guidance_start
self.skip_layer_guidance_stop = skip_layer_guidance_stop
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if not (0.0 <= skip_layer_guidance_start < 1.0):
raise ValueError(
f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
)
if not (0.0 < skip_layer_guidance_stop <= 1.0):
raise ValueError(
f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
)
if skip_layer_guidance_layers is None and skip_layer_config is None:
raise ValueError(
"Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
)
if skip_layer_guidance_layers is not None and skip_layer_config is not None:
raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")
if skip_layer_guidance_layers is not None:
if isinstance(skip_layer_guidance_layers, int):
skip_layer_guidance_layers = [skip_layer_guidance_layers]
if not isinstance(skip_layer_guidance_layers, list):
raise ValueError(
f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
)
skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
if isinstance(skip_layer_config, LayerSkipConfig):
skip_layer_config = [skip_layer_config]
if not isinstance(skip_layer_config, list):
raise ValueError(
f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
)
self.skip_layer_config = skip_layer_config
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
def prepare_models(self, denoiser: torch.nn.Module):
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
# Register the hooks for layer skipping if the step is within the specified range
if skip_start_step < self._step < skip_stop_step:
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
_apply_layer_skip_hook(denoiser, config, name=name)
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
num_conditions = self.num_conditions
list_of_inputs = []
for arg in args:
if isinstance(arg, torch.Tensor):
list_of_inputs.append([arg] * num_conditions)
elif isinstance(arg, (tuple, list)):
if len(arg) != 2:
raise ValueError(
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
f"with the first element being the conditional input and the second element being the unconditional input or None."
)
if arg[1] is None:
# Only conditioning inputs for all batches
list_of_inputs.append([arg[0]] * num_conditions)
else:
list_of_inputs.append([arg[0], arg[1], arg[0]])
else:
raise ValueError(
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
)
return tuple(list_of_inputs)
def prepare_outputs(self, pred: torch.Tensor) -> None:
self._num_outputs_prepared += 1
if self._num_outputs_prepared > self.num_conditions:
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
key = self._input_predictions[self._num_outputs_prepared - 1]
if not self._is_cfg_enabled() and self._is_slg_enabled():
# If we're predicting pred_cond and pred_cond_skip only, we need to set the key to pred_cond_skip
# to avoid writing into pred_uncond which is not used
if self._num_outputs_prepared == 2:
key = "pred_cond_skip"
self._preds[key] = pred
def cleanup_models(self, denoiser: torch.nn.Module):
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_slg_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_cond_skip
pred = pred + self.skip_layer_guidance_scale * shift
elif not self._is_slg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_skip = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_slg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if self.use_original_formulation:
return not math.isclose(self.guidance_scale, 0.0)
else:
return not math.isclose(self.guidance_scale, 1.0)
def _is_slg_enabled(self) -> bool:
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
return is_within_range and not is_zero
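# Illustrative sketch with assumed values of how the start/stop fractions above map to the
# denoising steps in which layer skipping is active (`_is_slg_enabled`).
num_inference_steps = 50
skip_layer_guidance_start, skip_layer_guidance_stop = 0.01, 0.2

skip_start_step = int(skip_layer_guidance_start * num_inference_steps)  # 0
skip_stop_step = int(skip_layer_guidance_stop * num_inference_steps)    # 10
active_steps = [step for step in range(num_inference_steps) if skip_start_step < step < skip_stop_step]
# active_steps == [1, 2, ..., 9]: skip-layer guidance contributes only in this window; outside it,
# `forward` falls back to plain classifier-free guidance.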

View File

@@ -1,9 +1,25 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_torch_available
if is_torch_available():
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layer_skip import LayerSkipConfig, apply_layer_skip
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

View File

@@ -0,0 +1,32 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.attention import FeedForward, LuminaFeedForward
from ..models.attention_processor import Attention, MochiAttention
_ATTENTION_CLASSES = (Attention, MochiAttention)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
{
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
}
)

View File

@@ -0,0 +1,276 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Callable, Type
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_cogview4 import (
CogView4AttnProcessor,
CogView4PAGAttnProcessor,
CogView4TransformerBlock,
)
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from ..models.transformers.transformer_hunyuan_video import (
HunyuanVideoSingleTransformerBlock,
HunyuanVideoTokenReplaceSingleTransformerBlock,
HunyuanVideoTokenReplaceTransformerBlock,
HunyuanVideoTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock
@dataclass
class AttentionProcessorMetadata:
skip_processor_output_fn: Callable[[Any], Any]
@dataclass
class GuidanceMetadata:
perturbed_attention_guidance_processor_cls: Type = None
@dataclass
class TransformerBlockMetadata:
skip_block_output_fn: Callable[[Any], Any]
return_hidden_states_index: int = None
return_encoder_hidden_states_index: int = None
class AttentionProcessorRegistry:
_registry = {}
@classmethod
def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> AttentionProcessorMetadata:
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
class GuidanceMetadataRegistry:
_registry = {}
@classmethod
def register(cls, model_class: Type, metadata: GuidanceMetadata):
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> GuidanceMetadata:
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
class TransformerBlockRegistry:
_registry = {}
@classmethod
def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> TransformerBlockMetadata:
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
def _register_attention_processors_metadata():
# CogView4
AttentionProcessorRegistry.register(
model_class=CogView4AttnProcessor,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
),
)
def _register_guidance_metadata():
# CogView4
GuidanceMetadataRegistry.register(
model_class=CogView4AttnProcessor,
metadata=GuidanceMetadata(
perturbed_attention_guidance_processor_cls=CogView4PAGAttnProcessor,
),
)
def _register_transformer_blocks_metadata():
# CogVideoX
TransformerBlockRegistry.register(
model_class=CogVideoXBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# CogView4
TransformerBlockRegistry.register(
model_class=CogView4TransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# Flux
TransformerBlockRegistry.register(
model_class=FluxTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock,
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
TransformerBlockRegistry.register(
model_class=FluxSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock,
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
# HunyuanVideo
TransformerBlockRegistry.register(
model_class=HunyuanVideoTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# LTXVideo
TransformerBlockRegistry.register(
model_class=LTXVideoTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# Mochi
TransformerBlockRegistry.register(
model_class=MochiTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# Wan
TransformerBlockRegistry.register(
model_class=WanTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# fmt: off
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return hidden_states, encoder_hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
return hidden_states
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return hidden_states, encoder_hidden_states
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return encoder_hidden_states, hidden_states
_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
# fmt: on
_register_attention_processors_metadata()
_register_guidance_metadata()
_register_transformer_blocks_metadata()
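# Illustrative sketch, using a hypothetical custom block class, of how the registries above are
# meant to be extended: register metadata once, then hooks look it up by the block's class.
import torch

class MyTransformerBlock(torch.nn.Module):  # hypothetical block returning (hidden_states, encoder_hidden_states)
    pass

def _skip_block_output_fn_MyTransformerBlock(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", args[0] if len(args) > 0 else None)
    encoder_hidden_states = kwargs.get("encoder_hidden_states", args[1] if len(args) > 1 else None)
    return hidden_states, encoder_hidden_states

TransformerBlockRegistry.register(
    model_class=MyTransformerBlock,
    metadata=TransformerBlockMetadata(
        skip_block_output_fn=_skip_block_output_fn_MyTransformerBlock,
        return_hidden_states_index=0,
        return_encoder_hidden_states_index=1,
    ),
)
metadata = TransformerBlockRegistry.get(MyTransformerBlock)  # lookup used by the skip/caching hooks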

View File

@@ -0,0 +1,223 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Tuple, Union
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from ._helpers import TransformerBlockRegistry
from .hooks import BaseMarkedState, HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
_FBC_BLOCK_HOOK = "fbc_block_hook"
@dataclass
class FirstBlockCacheConfig:
r"""
Configuration for [First Block
Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
Args:
threshold (`float`, defaults to `0.05`):
The threshold to determine whether or not a forward pass through all layers of the model is required. A
higher threshold usually results in a lower number of forward passes and faster inference, but might lead to
poorer generation quality. A lower threshold may not result in significant generation speedup. The
threshold is compared against the absmean difference of the residuals between the current and cached
outputs from the first transformer block. If the difference is below the threshold, the forward pass is
skipped.
"""
threshold: float = 0.05
class FBCSharedBlockState(BaseMarkedState):
def __init__(self) -> None:
super().__init__()
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.head_block_residual: torch.Tensor = None
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.should_compute: bool = True
def reset(self):
self.tail_block_residuals = None
self.should_compute = True
class FBCHeadBlockHook(ModelHook):
_is_stateful = True
def __init__(self, shared_state: FBCSharedBlockState, threshold: float):
self.shared_state = shared_state
self.threshold = threshold
self._metadata = None
def initialize_hook(self, module):
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
output = self.fn_ref.original_forward(*args, **kwargs)
is_output_tuple = isinstance(output, tuple)
if is_output_tuple:
hs_residual = output[self._metadata.return_hidden_states_index] - original_hs
else:
hs_residual = output - original_hs
hs, ehs = None, None
should_compute = self._should_compute_remaining_blocks(hs_residual)
self.shared_state.should_compute = should_compute
if not should_compute:
# Apply caching
if is_output_tuple:
hs = self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
else:
hs = self.shared_state.tail_block_residuals[0] + output
if self._metadata.return_encoder_hidden_states_index is not None:
assert is_output_tuple
ehs = (
self.shared_state.tail_block_residuals[1]
+ output[self._metadata.return_encoder_hidden_states_index]
)
if is_output_tuple:
return_output = [None] * len(output)
return_output[self._metadata.return_hidden_states_index] = hs
return_output[self._metadata.return_encoder_hidden_states_index] = ehs
return_output = tuple(return_output)
else:
return_output = hs
output = return_output
else:
if is_output_tuple:
head_block_output = [None] * len(output)
head_block_output[0] = output[self._metadata.return_hidden_states_index]
head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
else:
head_block_output = output
self.shared_state.head_block_output = head_block_output
self.shared_state.head_block_residual = hs_residual
return output
def reset_state(self, module):
self.shared_state.reset()
return module
@torch.compiler.disable
def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool:
if self.shared_state.head_block_residual is None:
return True
prev_hs_residual = self.shared_state.head_block_residual
hs_absmean = (hs_residual - prev_hs_residual).abs().mean()
prev_hs_mean = prev_hs_residual.abs().mean()
diff = (hs_absmean / prev_hs_mean).item()
return diff > self.threshold
class FBCBlockHook(ModelHook):
def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False):
super().__init__()
self.shared_state = shared_state
self.is_tail = is_tail
self._metadata = None
def initialize_hook(self, module):
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
if not isinstance(outputs_if_skipped, tuple):
outputs_if_skipped = (outputs_if_skipped,)
original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
original_ehs = None
if self._metadata.return_encoder_hidden_states_index is not None:
original_ehs = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index]
if self.shared_state.should_compute:
output = self.fn_ref.original_forward(*args, **kwargs)
if self.is_tail:
hs_residual, ehs_residual = None, None
if isinstance(output, tuple):
hs_residual = (
output[self._metadata.return_hidden_states_index] - self.shared_state.head_block_output[0]
)
ehs_residual = (
output[self._metadata.return_encoder_hidden_states_index]
- self.shared_state.head_block_output[1]
)
else:
hs_residual = output - self.shared_state.head_block_output
self.shared_state.tail_block_residuals = (hs_residual, ehs_residual)
return output
output_count = len(outputs_if_skipped)
if output_count == 1:
return_output = original_hs
else:
return_output = [None] * output_count
return_output[self._metadata.return_hidden_states_index] = original_hs
return_output[self._metadata.return_encoder_hidden_states_index] = original_ehs
return return_output
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
shared_state = FBCSharedBlockState()
remaining_blocks = []
for name, submodule in module.named_children():
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
continue
for index, block in enumerate(submodule):
remaining_blocks.append((f"{name}.{index}", block))
head_block_name, head_block = remaining_blocks.pop(0)
tail_block_name, tail_block = remaining_blocks.pop(-1)
logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'")
apply_fbc_head_block_hook(head_block, shared_state, config.threshold)
for name, block in remaining_blocks:
logger.debug(f"Apply FBCBlockHook to '{name}'")
apply_fbc_block_hook(block, shared_state)
logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'")
apply_fbc_block_hook(tail_block, shared_state, is_tail=True)
def apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCHeadBlockHook(state, threshold)
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
def apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCBlockHook(state, is_tail)
registry.register_hook(hook, _FBC_BLOCK_HOOK)
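# Illustrative sketch with assumed values of the caching decision made in
# `_should_compute_remaining_blocks`: the relative change of the first block's residual is
# compared against `FirstBlockCacheConfig.threshold`.
import torch

threshold = 0.05
prev_residual = torch.randn(1, 16, 128)  # first-block residual cached from the previous step
residual = prev_residual + 0.01 * torch.randn(1, 16, 128)  # current residual (assumed small change)

diff = ((residual - prev_residual).abs().mean() / prev_residual.abs().mean()).item()
should_compute = diff > threshold
# When `should_compute` is False, the remaining blocks are skipped and the cached tail residuals
# are added to the first block's output instead.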

View File

@@ -18,11 +18,76 @@ from typing import Any, Dict, Optional, Tuple
import torch
from ..utils.logging import get_logger
from ..utils.torch_utils import unwrap_module
logger = get_logger(__name__) # pylint: disable=invalid-name
class BaseState:
def reset(self, *args, **kwargs) -> None:
raise NotImplementedError(
"BaseState::reset is not implemented. Please implement this method in the derived class."
)
class BaseMarkedState(BaseState):
def __init__(self, init_args=None, init_kwargs=None):
super().__init__()
self._init_args = init_args if init_args is not None else ()
self._init_kwargs = init_kwargs if init_kwargs is not None else {}
self._mark_name = None
self._state_cache = {}
def get_current_state(self) -> "BaseMarkedState":
if self._mark_name is None:
# If no mark name is set, simply return a dummy object since we're not going to be using it
return self
if self._mark_name not in self._state_cache.keys():
self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs)
return self._state_cache[self._mark_name]
def mark_state(self, name: str) -> None:
self._mark_name = name
def reset(self, *args, **kwargs) -> None:
for name, state in list(self._state_cache.items()):
state.reset(*args, **kwargs)
self._state_cache.pop(name)
self._mark_name = None
def __getattribute__(self, name):
if name in (
"get_current_state",
"mark_state",
"reset",
"_init_args",
"_init_kwargs",
"_mark_name",
"_state_cache",
) or _is_dunder_method(name):
return object.__getattribute__(self, name)
else:
current_state = BaseMarkedState.get_current_state(self)
return object.__getattribute__(current_state, name)
def __setattr__(self, name, value):
if name in (
"get_current_state",
"mark_state",
"reset",
"_init_args",
"_init_kwargs",
"_mark_name",
"_state_cache",
) or _is_dunder_method(name):
object.__setattr__(self, name, value)
else:
current_state = BaseMarkedState.get_current_state(self)
object.__setattr__(current_state, name, value)
class ModelHook:
r"""
A hook that contains callbacks to be executed just before and after the forward method of a model.
@@ -99,6 +164,14 @@ class ModelHook:
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
return module
def _mark_state(self, module: torch.nn.Module, name: str) -> None:
# Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state` on them.
for attr_name in dir(self):
attr = getattr(self, attr_name)
if isinstance(attr, BaseMarkedState):
attr.mark_state(name)
return module
class HookFunctionReference:
def __init__(self) -> None:
@@ -211,9 +284,10 @@ class HookRegistry:
hook.reset_state(self._module_ref)
if recurse:
for module_name, module in self._module_ref.named_modules():
for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.reset_stateful_hooks(recurse=False)
@@ -223,6 +297,19 @@ class HookRegistry:
module._diffusers_hook = cls(module)
return module._diffusers_hook
def _mark_state(self, name: str) -> None:
for hook_name in reversed(self._hook_order):
hook = self.hooks[hook_name]
if hook._is_stateful:
hook._mark_state(self._module_ref, name)
for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook._mark_state(name)
def __repr__(self) -> str:
registry_repr = ""
for i, hook_name in enumerate(self._hook_order):
@@ -234,3 +321,7 @@ class HookRegistry:
if i < len(self._hook_order) - 1:
registry_repr += "\n"
return f"HookRegistry(\n{registry_repr}\n)"
def _is_dunder_method(name: str) -> bool:
return name.startswith("__") and name.endswith("__") and name in dir(object)
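# Illustrative sketch with a hypothetical subclass of how `BaseMarkedState` keeps an independent
# copy of its attributes per mark name, which is what lets a single stateful hook track, e.g.,
# conditional and unconditional batches separately.
class CounterState(BaseMarkedState):
    def __init__(self):
        super().__init__()
        self.count = 0

state = CounterState()
state.mark_state("cond")
state.count += 1
state.mark_state("uncond")
state.count += 5
state.mark_state("cond")
assert state.count == 1  # the "cond" copy is untouched by updates made under "uncond"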

View File

@@ -0,0 +1,182 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, List, Optional
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_LAYER_SKIP_HOOK = "layer_skip_hook"
@dataclass
class LayerSkipConfig:
r"""
Configuration for skipping internal transformer blocks when executing a transformer model.
Args:
indices (`List[int]`):
The indices of the layers to skip. This is typically the first layer in the transformer block.
fqn (`str`, defaults to `"auto"`):
The fully qualified name identifying the stack of transformer blocks. Typically, this is
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
"""
indices: List[int]
fqn: str = "auto"
skip_attention: bool = True
skip_attention_scores: bool = False
skip_ff: bool = True
class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
def __init__(self) -> None:
super().__init__()
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func is torch.nn.functional.scaled_dot_product_attention:
value = kwargs.get("value", None)
if value is None:
value = args[2]
return value
return func(*args, **kwargs)
class AttentionProcessorSkipHook(ModelHook):
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False):
self.skip_processor_output_fn = skip_processor_output_fn
self.skip_attention_scores = skip_attention_scores
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if self.skip_attention_scores:
with AttentionScoreSkipFunctionMode():
return self.fn_ref.original_forward(*args, **kwargs)
else:
return self.skip_processor_output_fn(module, *args, **kwargs)
class FeedForwardSkipHook(ModelHook):
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
output = kwargs.get("hidden_states", None)
if output is None:
output = kwargs.get("x", None)
if output is None and len(args) > 0:
output = args[0]
return output
class TransformerBlockSkipHook(ModelHook):
def initialize_hook(self, module):
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
return self._metadata.skip_block_output_fn(module, *args, **kwargs)
def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
r"""
Apply layer skipping to internal layers of a transformer.
Args:
module (`torch.nn.Module`):
The transformer model to which the layer skip hook should be applied.
config (`LayerSkipConfig`):
The configuration for the layer skip hook.
Example:
```python
>>> import torch
>>> from diffusers import CogVideoXTransformer3DModel
>>> from diffusers.hooks import LayerSkipConfig, apply_layer_skip
>>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> config = LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks")
>>> apply_layer_skip(transformer, config)
```
"""
_apply_layer_skip_hook(module, config)
def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
name = name or _LAYER_SKIP_HOOK
if config.skip_attention and config.skip_attention_scores:
raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
if config.fqn == "auto":
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
if hasattr(module, identifier):
config.fqn = identifier
break
else:
raise ValueError(
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
)
transformer_blocks = getattr(module, config.fqn, None)
if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
raise ValueError(
f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
)
if len(config.indices) == 0:
raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")
blocks_found = False
for i, block in enumerate(transformer_blocks):
if i not in config.indices:
continue
blocks_found = True
if config.skip_attention and config.skip_ff:
logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = TransformerBlockSkipHook()
registry.register_hook(hook, name)
elif config.skip_attention or config.skip_attention_scores:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores)
registry.register_hook(hook, name)
elif config.skip_ff:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _FEEDFORWARD_CLASSES):
logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = FeedForwardSkipHook()
registry.register_hook(hook, name)
else:
raise ValueError(
"At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True."
)
if not blocks_found:
raise ValueError(
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
)
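
As a quick illustration of how `AttentionScoreSkipFunctionMode` takes effect when `skip_attention_scores=True` (a standalone sketch, assuming the class defined above is in scope): inside the mode, `scaled_dot_product_attention` is intercepted and the value tensor is returned unchanged.

```python
import torch
import torch.nn.functional as F

query = torch.randn(1, 2, 4, 8)
key = torch.randn(1, 2, 4, 8)
value = torch.randn(1, 2, 4, 8)

with AttentionScoreSkipFunctionMode():
    out = F.scaled_dot_product_attention(query, key, value)

# The score computation was bypassed; the output is exactly the value tensor.
assert torch.equal(out, value)
```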

View File

@@ -44,6 +44,7 @@ from ..utils import (
is_transformers_available,
logging,
)
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file
@@ -443,7 +444,7 @@ def fetch_original_config(original_config_file, local_files_only=False):
"Please provide a valid local file path."
)
original_config_file = BytesIO(requests.get(original_config_file).content)
original_config_file = BytesIO(requests.get(original_config_file, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
else:
raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
@@ -2406,7 +2407,6 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
"timestep_scale_multiplier": remove_keys_,
}
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from ..utils.logging import get_logger
@@ -25,6 +27,7 @@ class CacheMixin:
Supported caching techniques:
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
- [FasterCache](https://huggingface.co/papers/2410.19355)
- [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
"""
_cache_config = None
@@ -62,8 +65,10 @@ class CacheMixin:
from ..hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_first_block_cache,
apply_pyramid_attention_broadcast,
)
@@ -72,31 +77,36 @@ class CacheMixin:
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
)
if isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, FasterCacheConfig):
if isinstance(config, FasterCacheConfig):
apply_faster_cache(self, config)
elif isinstance(config, FirstBlockCacheConfig):
apply_first_block_cache(self, config)
elif isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")
self._cache_config = config
def disable_cache(self) -> None:
from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
return
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, FasterCacheConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry = HookRegistry.check_if_exists_or_initialize(self)
if isinstance(self._cache_config, FasterCacheConfig):
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, FirstBlockCacheConfig):
registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
@@ -106,3 +116,21 @@ class CacheMixin:
from ..hooks import HookRegistry
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
@contextmanager
def _cache_context(self):
r"""Context manager that provides additional methods for cache management."""
cache_context = _CacheContextManager(self)
yield cache_context
class _CacheContextManager:
def __init__(self, model: CacheMixin):
self.model = model
def mark_state(self, name: str) -> None:
from ..hooks import HookRegistry
if self.model.is_cache_enabled:
registry = HookRegistry.check_if_exists_or_initialize(self.model)
registry._mark_state(name)
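
A minimal usage sketch of the newly supported FirstBlockCache via `CacheMixin.enable_cache` (the `threshold` field name and the model id are assumptions here, not taken from this diff):

```python
import torch
from diffusers import FirstBlockCacheConfig, FluxPipeline

# model id assumed for illustration
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")

# enable_cache dispatches FirstBlockCacheConfig to apply_first_block_cache (see above).
# `threshold` is an assumed field name for the residual-difference threshold.
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))
image = pipe("a cat wearing sunglasses", num_inference_steps=28).images[0]

# Hooks are removed again based on the stored config type.
pipe.transformer.disable_cache()
```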

View File

@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
block_samples = block_samples + (hidden_states,)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
hidden_states = block(
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
single_block_samples = single_block_samples + (hidden_states,)
# controlnet block
controlnet_block_samples = ()

View File

@@ -460,3 +460,84 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
### ===== Custom attention processors for guidance methods ===== ###
class CogView4PAGAttnProcessor:
"""
Processor implementing the perturbed attention path of Perturbed Attention Guidance (PAG) for the CogView4 model,
using scaled dot-product attention. It applies a rotary embedding on query and key vectors, but does not include
spatial normalization.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
skip_context_attention: bool = False,
) -> torch.Tensor:
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
batch_size, image_seq_length, embed_dim = hidden_states.shape
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# 1. QKV projections
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
# 2. QK normalization
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# 3. Rotational positional embeddings applied to latent stream
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
query[:, :, text_seq_length:, :] = apply_rotary_emb(
query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
key[:, :, text_seq_length:, :] = apply_rotary_emb(
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
# 4. Attention
if skip_context_attention:
hidden_states = value
else:
# PAG uses a custom attention mask for perturbed attention path:
# - Create attention mask with `float("-inf")` for image tokens and `0.0` for text tokens
# - Set diagonal to `0.0` for attention between image tokens
seq_length = text_seq_length + image_seq_length
perturbed_attention_mask = hidden_states.new_full((seq_length, seq_length), float("-inf"))
perturbed_attention_mask[:text_seq_length, :text_seq_length] = 0.0
perturbed_attention_mask.fill_diagonal_(0.0)
perturbed_attention_mask = perturbed_attention_mask.unsqueeze(0).unsqueeze(0)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=perturbed_attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
# 5. Output projection
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
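
For reference, a tiny standalone sketch of the perturbed-attention mask built above, with made-up sequence lengths: text tokens attend to the full text prefix, while each image token attends only to itself.

```python
import torch

text_seq_length, image_seq_length = 2, 3  # toy sizes for illustration
seq_length = text_seq_length + image_seq_length

mask = torch.full((seq_length, seq_length), float("-inf"))
mask[:text_seq_length, :text_seq_length] = 0.0  # text-to-text attention is kept
mask.fill_diagonal_(0.0)  # every token may attend to itself
print(mask)
# tensor([[0., 0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf],
#         [-inf, -inf, 0., -inf, -inf],
#         [-inf, -inf, -inf, 0., -inf],
#         [-inf, -inf, -inf, -inf, 0.]])
```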

View File

@@ -79,10 +79,14 @@ class FluxSingleTransformerBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -100,7 +104,8 @@ class FluxSingleTransformerBlock(nn.Module):
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
return encoder_hidden_states, hidden_states
@maybe_allow_in_graph
@@ -508,20 +513,21 @@ class FluxTransformer2DModel(
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
hidden_states = block(
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
@@ -531,12 +537,7 @@ class FluxTransformer2DModel(
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)

View File

@@ -21,6 +21,7 @@ import torch
from transformers import AutoTokenizer, GlmModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...guiders import ClassifierFreeGuidance, GuidanceMixin, _raise_guidance_deprecation_warning
from ...image_processor import VaeImageProcessor
from ...loaders import CogView4LoraLoaderMixin
from ...models import AutoencoderKL, CogView4Transformer2DModel
@@ -426,6 +427,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 1024,
guidance: Optional[GuidanceMixin] = None,
) -> Union[CogView4PipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -514,6 +516,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
_raise_guidance_deprecation_warning(guidance_scale=guidance_scale)
if guidance is None:
guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
@@ -608,46 +614,47 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
transformer_dtype = self.transformer.dtype
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
conds = [prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left]
prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[v] for v in conds]
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
self._current_timestep = t
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents.to(transformer_dtype)
guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
guidance.prepare_models(self.transformer)
latents, prompt_embeds, original_size, target_size, crops_coords_top_left = guidance.prepare_inputs(
latents,
(prompt_embeds[0], negative_prompt_embeds[0]),
original_size[0],
target_size[0],
crops_coords_top_left[0],
)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=negative_prompt_embeds,
for batch_index, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate(
zip(latents, prompt_embeds, original_size, target_size, crops_coords_top_left)
):
cc.mark_state(f"batch_{batch_index}")
latent = latent.to(transformer_dtype)
timestep = t.expand(latent.shape[0])
noise_pred = self.transformer(
hidden_states=latent,
encoder_hidden_states=condition,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
original_size=original_size_c,
target_size=target_size_c,
crop_coords=crop_coord_c,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
guidance.prepare_outputs(noise_pred)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
outputs = guidance.outputs
noise_pred = guidance(**outputs)
latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
guidance.cleanup_models(self.transformer)
# call the callback, if provided
if callback_on_step_end is not None:
@@ -656,8 +663,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
prompt_embeds = [callback_outputs.pop("prompt_embeds", prompt_embeds[0])]
negative_prompt_embeds = [
callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0])
]
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
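
A hedged usage sketch of the new `guidance` argument on `CogView4Pipeline` (the model id is an assumption; when `guidance` is omitted, the pipeline falls back to `ClassifierFreeGuidance(guidance_scale=guidance_scale)` as shown above):

```python
import torch
from diffusers import ClassifierFreeGuidance, CogView4Pipeline

# model id assumed for illustration
pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")

# Explicitly construct the guider instead of relying on `guidance_scale`.
guidance = ClassifierFreeGuidance(guidance_scale=5.0)
image = pipe(
    prompt="a watercolor painting of a lighthouse at dawn",
    guidance=guidance,
    num_inference_steps=50,
).images[0]
image.save("cogview4.png")
```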

View File

@@ -906,7 +906,7 @@ class FluxPipeline(
)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
@@ -917,6 +917,7 @@ class FluxPipeline(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
cc.mark_state("cond")
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
@@ -932,6 +933,8 @@ class FluxPipeline(
if do_true_cfg:
if negative_image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
cc.mark_state("uncond")
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,

View File

@@ -224,11 +224,13 @@ class FluxFillPipeline(
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor * 2,
vae_latent_channels=latent_channels,
vae_latent_channels=self.latent_channels,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
@@ -493,10 +495,38 @@ class FluxFillPipeline(
return prompt_embeds, pooled_prompt_embeds, text_ids
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
return image_latents
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps)
t_start = int(max(num_inference_steps - init_timestep, 0))
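# Worked example (illustrative): with num_inference_steps=50 and strength=0.6,
# init_timestep = min(50 * 0.6, 50) = 30 and t_start = 50 - 30 = 20, so denoising
# runs over the final 30 of the 50 scheduled timesteps.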
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
def check_inputs(
self,
prompt,
prompt_2,
strength,
height,
width,
prompt_embeds=None,
@@ -507,6 +537,9 @@ class FluxFillPipeline(
mask_image=None,
masked_image_latents=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
@@ -624,9 +657,11 @@ class FluxFillPipeline(
"""
self.vae.disable_tiling()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
def prepare_latents(
self,
image,
timestep,
batch_size,
num_channels_latents,
height,
@@ -636,28 +671,41 @@ class FluxFillPipeline(
generator,
latents=None,
):
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
return latents.to(device=device, dtype=dtype), latent_image_ids
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
if latents is not None:
return latents.to(device=device, dtype=dtype), latent_image_ids
image = image.to(device=device, dtype=dtype)
if image.shape[1] != self.latent_channels:
image_latents = self._encode_vae_image(image=image, generator=generator)
else:
image_latents = image
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([image_latents], dim=0)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
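# `scale_noise` (flow-matching schedulers) blends the image latents with noise according to
# the sigma at `timestep`, so `strength` < 1.0 starts denoising closer to the original image.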
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
return latents, latent_image_ids
@property
@@ -687,6 +735,7 @@ class FluxFillPipeline(
masked_image_latents: Optional[torch.FloatTensor] = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 1.0,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 30.0,
@@ -731,6 +780,12 @@ class FluxFillPipeline(
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
strength (`float`, *optional*, defaults to 1.0):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
essentially ignores `image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
@@ -794,6 +849,7 @@ class FluxFillPipeline(
self.check_inputs(
prompt,
prompt_2,
strength,
height,
width,
prompt_embeds=prompt_embeds,
@@ -809,6 +865,9 @@ class FluxFillPipeline(
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
@@ -838,47 +897,9 @@ class FluxFillPipeline(
lora_scale=lora_scale,
)
# 4. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 5. Prepare mask and masked image latents
if masked_image_latents is not None:
masked_image_latents = masked_image_latents.to(latents.device)
else:
image = self.image_processor.preprocess(image, height=height, width=width)
mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
masked_image = image * (1 - mask_image)
masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
height, width = image.shape[-2:]
mask, masked_image_latents = self.prepare_mask_latents(
mask_image,
masked_image,
batch_size,
num_channels_latents,
num_images_per_prompt,
height,
width,
prompt_embeds.dtype,
device,
generator,
)
masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
# 6. Prepare timesteps
# 4. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
@@ -893,6 +914,54 @@ class FluxFillPipeline(
sigmas=sigmas,
mu=mu,
)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
if num_inference_steps < 1:
raise ValueError(
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 5. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
latents, latent_image_ids = self.prepare_latents(
init_image,
latent_timestep,
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare mask and masked image latents
if masked_image_latents is not None:
masked_image_latents = masked_image_latents.to(latents.device)
else:
mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
masked_image = init_image * (1 - mask_image)
masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
height, width = init_image.shape[-2:]
mask, masked_image_latents = self.prepare_mask_latents(
mask_image,
masked_image,
batch_size,
num_channels_latents,
num_images_per_prompt,
height,
width,
prompt_embeds.dtype,
device,
generator,
)
masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
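
A hedged end-to-end sketch of the new `strength` argument on `FluxFillPipeline` (the model id and image URLs are placeholders, not taken from this diff):

```python
import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

# model id assumed for illustration
pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16).to("cuda")

image = load_image("https://example.com/room.png")  # placeholder URL
mask_image = load_image("https://example.com/room_mask.png")  # placeholder URL

result = pipe(
    prompt="a red sofa in a bright living room",
    image=image,
    mask_image=mask_image,
    strength=0.85,  # < 1.0 preserves some of the original image content
    num_inference_steps=50,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
result.save("flux_fill.png")
```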

View File

@@ -683,7 +683,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
@@ -693,6 +693,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
cc.mark_state("cond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
@@ -705,6 +706,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
)[0]
if do_true_cfg:
cc.mark_state("uncond")
neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,

View File

@@ -19,7 +19,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import KarrasDiffusionSchedulers
@@ -121,7 +121,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin):
class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin):
r"""
Pipeline for text-to-image generation using Kolors.
@@ -129,8 +129,8 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
The pipeline also inherits the following loading methods:
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:

View File

@@ -706,7 +706,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
)
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
@@ -719,6 +719,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
cc.mark_state("cond_uncond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,

View File

@@ -14,7 +14,7 @@
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import PIL.Image
import torch
@@ -75,6 +75,7 @@ EXAMPLE_DOC_STRING = """
>>> # Generate video
>>> generator = torch.Generator("cuda").manual_seed(0)
>>> # Text-only conditioning is also supported without the need to pass `conditions`
>>> video = pipe(
... conditions=[condition1, condition2],
... prompt=prompt,
@@ -223,7 +224,7 @@ def retrieve_latents(
class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
r"""
Pipeline for image-to-video generation.
Pipeline for text/image/video-to-video generation.
Reference: https://github.com/Lightricks/LTX-Video
@@ -482,9 +483,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
if conditions is not None and (image is not None or video is not None):
raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.")
if conditions is None and (image is None and video is None):
raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.")
if conditions is None:
if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index):
raise ValueError(
@@ -642,9 +640,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
def prepare_latents(
self,
conditions: List[torch.Tensor],
condition_strength: List[float],
condition_frame_index: List[int],
conditions: Optional[List[torch.Tensor]] = None,
condition_strength: Optional[List[float]] = None,
condition_frame_index: Optional[List[int]] = None,
batch_size: int = 1,
num_channels_latents: int = 128,
height: int = 512,
@@ -654,7 +652,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
generator: Optional[torch.Generator] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
@@ -662,77 +660,80 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)
if len(conditions) > 0:
condition_latent_frames_mask = torch.zeros(
(batch_size, num_latent_frames), device=device, dtype=torch.float32
)
extra_conditioning_latents = []
extra_conditioning_video_ids = []
extra_conditioning_mask = []
extra_conditioning_num_latents = 0
for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
condition_latents = self._normalize_latents(
condition_latents, self.vae.latents_mean, self.vae.latents_std
).to(device, dtype=dtype)
extra_conditioning_latents = []
extra_conditioning_video_ids = []
extra_conditioning_mask = []
extra_conditioning_num_latents = 0
for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
condition_latents = self._normalize_latents(
condition_latents, self.vae.latents_mean, self.vae.latents_std
).to(device, dtype=dtype)
num_data_frames = data.size(2)
num_cond_frames = condition_latents.size(2)
num_data_frames = data.size(2)
num_cond_frames = condition_latents.size(2)
if frame_index == 0:
latents[:, :, :num_cond_frames] = torch.lerp(
latents[:, :, :num_cond_frames], condition_latents, strength
)
condition_latent_frames_mask[:, :num_cond_frames] = strength
if frame_index == 0:
latents[:, :, :num_cond_frames] = torch.lerp(
latents[:, :, :num_cond_frames], condition_latents, strength
)
condition_latent_frames_mask[:, :num_cond_frames] = strength
else:
if num_data_frames > 1:
if num_cond_frames < num_prefix_latent_frames:
raise ValueError(
f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
)
else:
if num_data_frames > 1:
if num_cond_frames < num_prefix_latent_frames:
raise ValueError(
f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
)
if num_cond_frames > num_prefix_latent_frames:
start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
latents[:, :, start_frame:end_frame] = torch.lerp(
latents[:, :, start_frame:end_frame],
condition_latents[:, :, num_prefix_latent_frames:],
strength,
)
condition_latent_frames_mask[:, start_frame:end_frame] = strength
condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
if num_cond_frames > num_prefix_latent_frames:
start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
latents[:, :, start_frame:end_frame] = torch.lerp(
latents[:, :, start_frame:end_frame],
condition_latents[:, :, num_prefix_latent_frames:],
strength,
)
condition_latent_frames_mask[:, start_frame:end_frame] = strength
condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
condition_latents = torch.lerp(noise, condition_latents, strength)
noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
condition_latents = torch.lerp(noise, condition_latents, strength)
condition_video_ids = self._prepare_video_ids(
batch_size,
condition_latents.size(2),
latent_height,
latent_width,
patch_size=self.transformer_spatial_patch_size,
patch_size_t=self.transformer_temporal_patch_size,
device=device,
)
condition_video_ids = self._scale_video_ids(
condition_video_ids,
scale_factor=self.vae_spatial_compression_ratio,
scale_factor_t=self.vae_temporal_compression_ratio,
frame_index=frame_index,
device=device,
)
condition_latents = self._pack_latents(
condition_latents,
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
condition_conditioning_mask = torch.full(
condition_latents.shape[:2], strength, device=device, dtype=dtype
)
condition_video_ids = self._prepare_video_ids(
batch_size,
condition_latents.size(2),
latent_height,
latent_width,
patch_size=self.transformer_spatial_patch_size,
patch_size_t=self.transformer_temporal_patch_size,
device=device,
)
condition_video_ids = self._scale_video_ids(
condition_video_ids,
scale_factor=self.vae_spatial_compression_ratio,
scale_factor_t=self.vae_temporal_compression_ratio,
frame_index=frame_index,
device=device,
)
condition_latents = self._pack_latents(
condition_latents,
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
condition_conditioning_mask = torch.full(
condition_latents.shape[:2], strength, device=device, dtype=dtype
)
extra_conditioning_latents.append(condition_latents)
extra_conditioning_video_ids.append(condition_video_ids)
extra_conditioning_mask.append(condition_conditioning_mask)
extra_conditioning_num_latents += condition_latents.size(1)
extra_conditioning_latents.append(condition_latents)
extra_conditioning_video_ids.append(condition_video_ids)
extra_conditioning_mask.append(condition_conditioning_mask)
extra_conditioning_num_latents += condition_latents.size(1)
video_ids = self._prepare_video_ids(
batch_size,
@@ -743,7 +744,10 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
patch_size=self.transformer_spatial_patch_size,
device=device,
)
conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
if len(conditions) > 0:
conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
else:
conditioning_mask, extra_conditioning_num_latents = None, 0
video_ids = self._scale_video_ids(
video_ids,
scale_factor=self.vae_spatial_compression_ratio,
@@ -755,7 +759,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
)
if len(extra_conditioning_latents) > 0:
if len(conditions) > 0 and len(extra_conditioning_latents) > 0:
latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2)
conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1)
@@ -955,7 +959,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
frame_index = [condition.frame_index for condition in conditions]
image = [condition.image for condition in conditions]
video = [condition.video for condition in conditions]
else:
elif image is not None or video is not None:
if not isinstance(image, list):
image = [image]
num_conditions = 1
@@ -999,32 +1003,34 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
vae_dtype = self.vae.dtype
conditioning_tensors = []
for condition_image, condition_video, condition_frame_index, condition_strength in zip(
image, video, frame_index, strength
):
if condition_image is not None:
condition_tensor = (
self.video_processor.preprocess(condition_image, height, width)
.unsqueeze(2)
.to(device, dtype=vae_dtype)
)
elif condition_video is not None:
condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
num_frames_input = condition_tensor.size(2)
num_frames_output = self.trim_conditioning_sequence(
condition_frame_index, num_frames_input, num_frames
)
condition_tensor = condition_tensor[:, :, :num_frames_output]
condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
else:
raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
is_conditioning_image_or_video = image is not None or video is not None
if is_conditioning_image_or_video:
for condition_image, condition_video, condition_frame_index, condition_strength in zip(
image, video, frame_index, strength
):
if condition_image is not None:
condition_tensor = (
self.video_processor.preprocess(condition_image, height, width)
.unsqueeze(2)
.to(device, dtype=vae_dtype)
)
elif condition_video is not None:
condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
num_frames_input = condition_tensor.size(2)
num_frames_output = self.trim_conditioning_sequence(
condition_frame_index, num_frames_input, num_frames
)
condition_tensor = condition_tensor[:, :, :num_frames_output]
condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
else:
raise ValueError("Either `image` or `video` must be provided for conditioning.")
if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
raise ValueError(
f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
f"but got {condition_tensor.size(2)} frames."
)
conditioning_tensors.append(condition_tensor)
if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
raise ValueError(
f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
f"but got {condition_tensor.size(2)} frames."
)
conditioning_tensors.append(condition_tensor)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
@@ -1045,7 +1051,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
video_coords = video_coords.float()
video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)
init_latents = latents.clone()
init_latents = latents.clone() if is_conditioning_image_or_video else None
if self.do_classifier_free_guidance:
video_coords = torch.cat([video_coords, video_coords], dim=0)
@@ -1065,15 +1071,15 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
if image_cond_noise_scale > 0:
if image_cond_noise_scale > 0 and init_latents is not None:
# Add timestep-dependent noise to the hard-conditioning latents
# This helps with motion continuity, especially when conditioned on a single frame
latents = self.add_noise_to_image_conditioning_latents(
@@ -1086,17 +1092,20 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
)
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
conditioning_mask_model_input = (
torch.cat([conditioning_mask, conditioning_mask])
if self.do_classifier_free_guidance
else conditioning_mask
)
if is_conditioning_image_or_video:
conditioning_mask_model_input = (
torch.cat([conditioning_mask, conditioning_mask])
if self.do_classifier_free_guidance
else conditioning_mask
)
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
if is_conditioning_image_or_video:
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
cc.mark_state("cond_uncond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
@@ -1115,8 +1124,11 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
denoised_latents = self.scheduler.step(
-noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False
)[0]
tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
if is_conditioning_image_or_video:
tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
else:
latents = denoised_latents
if callback_on_step_end is not None:
callback_kwargs = {}
@@ -1134,7 +1146,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
if XLA_AVAILABLE:
xm.mark_step()
latents = latents[:, extra_conditioning_num_latents:]
if is_conditioning_image_or_video:
latents = latents[:, extra_conditioning_num_latents:]
latents = self._unpack_latents(
latents,
latent_num_frames,
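
A hedged sketch of the text-only conditioning path enabled by this refactor, where no `conditions`, `image`, or `video` is passed (the model id is an assumption):

```python
import torch
from diffusers import LTXConditionPipeline
from diffusers.utils import export_to_video

# model id assumed for illustration
pipe = LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.5", torch_dtype=torch.bfloat16).to("cuda")

video = pipe(
    prompt="a drone shot of a rocky coastline at sunset, gentle waves",
    negative_prompt="worst quality, blurry, jittery, distorted",
    width=768,
    height=512,
    num_frames=121,
    num_inference_steps=40,
    generator=torch.Generator("cuda").manual_seed(0),
).frames[0]
export_to_video(video, "output.mp4", fps=24)
```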

View File

@@ -778,7 +778,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
)
# 7. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
@@ -792,6 +792,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
timestep = t.expand(latent_model_input.shape[0])
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
cc.mark_state("cond_uncond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,

View File

@@ -52,6 +52,7 @@ from ...schedulers import (
UnCLIPScheduler,
)
from ...utils import is_accelerate_available, logging
from ...utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
from ..paint_by_example import PaintByExampleImageEncoder
from ..pipeline_utils import DiffusionPipeline
@@ -1324,7 +1325,7 @@ def download_from_original_stable_diffusion_ckpt(
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)
original_config_file = BytesIO(requests.get(config_url, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
else:
with open(original_config_file, "r") as f:
original_config_file = f.read()

View File

@@ -519,7 +519,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
@@ -528,6 +528,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
cc.mark_state("cond")
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
@@ -537,6 +538,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
)[0]
if self.do_classifier_free_guidance:
cc.mark_state("uncond")
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,

View File

@@ -40,6 +40,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://hugging
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
DIFFUSERS_REQUEST_TIMEOUT = 60
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are

View File

@@ -2,6 +2,81 @@
from ..utils import DummyObject, requires_backends
class AdaptiveProjectedGuidance(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ClassifierFreeGuidance(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ClassifierFreeZeroStarGuidance(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PerturbedAttentionGuidance(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class SkipLayerGuidance(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class FasterCacheConfig(metaclass=DummyObject):
_backends = ["torch"]
@@ -17,6 +92,21 @@ class FasterCacheConfig(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class FirstBlockCacheConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HookRegistry(metaclass=DummyObject):
_backends = ["torch"]
@@ -32,6 +122,21 @@ class HookRegistry(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class LayerSkipConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
_backends = ["torch"]
@@ -51,6 +156,14 @@ def apply_faster_cache(*args, **kwargs):
requires_backends(apply_faster_cache, ["torch"])
def apply_first_block_cache(*args, **kwargs):
requires_backends(apply_first_block_cache, ["torch"])
def apply_layer_skip(*args, **kwargs):
requires_backends(apply_layer_skip, ["torch"])
def apply_pyramid_attention_broadcast(*args, **kwargs):
requires_backends(apply_pyramid_attention_broadcast, ["torch"])

View File

@@ -7,6 +7,7 @@ import PIL.Image
import PIL.ImageOps
import requests
from .constants import DIFFUSERS_REQUEST_TIMEOUT
from .import_utils import BACKENDS_MAPPING, is_imageio_available
@@ -29,7 +30,7 @@ def load_image(
"""
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
image = PIL.Image.open(requests.get(image, stream=True).raw)
image = PIL.Image.open(requests.get(image, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
elif os.path.isfile(image):
image = PIL.Image.open(image)
else:

View File

@@ -26,6 +26,7 @@ import requests
from numpy.linalg import norm
from packaging import version
from .constants import DIFFUSERS_REQUEST_TIMEOUT
from .import_utils import (
BACKENDS_MAPPING,
is_accelerate_available,
@@ -594,7 +595,7 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
# local_path can be passed to correct images of tests
return Path(local_path, arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]).as_posix()
elif arry.startswith("http://") or arry.startswith("https://"):
response = requests.get(arry)
response = requests.get(arry, timeout=DIFFUSERS_REQUEST_TIMEOUT)
response.raise_for_status()
arry = np.load(BytesIO(response.content))
elif os.path.isfile(arry):
@@ -615,7 +616,7 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
def load_pt(url: str, map_location: str):
response = requests.get(url)
response = requests.get(url, timeout=DIFFUSERS_REQUEST_TIMEOUT)
response.raise_for_status()
arry = torch.load(BytesIO(response.content), map_location=map_location)
return arry
@@ -634,7 +635,7 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
"""
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
image = PIL.Image.open(requests.get(image, stream=True).raw)
image = PIL.Image.open(requests.get(image, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
elif os.path.isfile(image):
image = PIL.Image.open(image)
else:

View File

@@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool:
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
def unwrap_module(module):
"""Unwraps a module if it was compiled with torch.compile()"""
return module._orig_mod if is_compiled_module(module) else module
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).

View File

@@ -32,6 +32,7 @@ from diffusers.utils.testing_utils import (
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
@@ -44,7 +45,11 @@ enable_full_determinism()
class CogVideoXPipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = CogVideoXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}

View File

@@ -25,6 +25,7 @@ from diffusers.utils.testing_utils import (
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
@@ -34,11 +35,12 @@ from ..test_pipelines_common import (
class FluxPipelineFastTests(
unittest.TestCase,
PipelineTesterMixin,
FluxIPAdapterTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = FluxPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])

View File

@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
to_np,
@@ -43,7 +44,11 @@ enable_full_determinism()
class HunyuanVideoPipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = HunyuanVideoPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])

View File

@@ -23,13 +23,13 @@ from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LT
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase):
pipeline_class = LTXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = LTXVideoTransformer3DModel(
in_channels=8,
@@ -59,7 +59,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
num_attention_heads=4,
attention_head_dim=8,
cross_attention_dim=32,
num_layers=1,
num_layers=num_layers,
caption_channels=32,
)

View File

@@ -33,13 +33,15 @@ from diffusers.utils.testing_utils import (
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
class MochiPipelineFastTests(
PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase
):
pipeline_class = MochiPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

View File

@@ -33,6 +33,7 @@ from diffusers import (
)
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
@@ -2631,7 +2632,7 @@ class FasterCacheTesterMixin:
self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
pipe = create_pipe()
pipe.transformer.enable_cache(self.faster_cache_config)
output = run_forward(pipe).flatten().flatten()
output = run_forward(pipe).flatten()
image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))
# Run inference with FasterCache disabled
@@ -2738,6 +2739,55 @@ class FasterCacheTesterMixin:
self.assertTrue(state.cache is None, "Cache should be reset to None.")
# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
# of the box once there is better cache support/implementation
class FirstBlockCacheTesterMixin:
# threshold is intentionally set higher than usual because these tests use tiny, randomly
# initialized models whose outputs lack the smoothness a trained denoiser would need for
# caching to trigger at realistic thresholds
first_block_cache_config = FirstBlockCacheConfig(threshold=0.8)
def test_first_block_cache_inference(self, expected_atol: float = 0.1):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
def create_pipe():
torch.manual_seed(0)
num_layers = 2
components = self.get_dummy_components(num_layers=num_layers)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
return pipe
def run_forward(pipe):
torch.manual_seed(0)
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
return pipe(**inputs)[0]
# Run inference without FirstBlockCache
pipe = create_pipe()
output = run_forward(pipe).flatten()
original_image_slice = np.concatenate((output[:8], output[-8:]))
# Run inference with FirstBlockCache enabled
pipe = create_pipe()
pipe.transformer.enable_cache(self.first_block_cache_config)
output = run_forward(pipe).flatten()
image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))
# Run inference with FirstBlockCache disabled
pipe.transformer.disable_cache()
output = run_forward(pipe).flatten()
image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
assert np.allclose(
original_image_slice, image_slice_fbc_enabled, atol=expected_atol
), "FirstBlockCache outputs should not differ much."
assert np.allclose(
original_image_slice, image_slice_fbc_disabled, atol=1e-4
), "Outputs from normal inference and after disabling cache should not differ."
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.
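
The new FirstBlockCacheTesterMixin exercises the same surface a user would touch: enable_cache() on the denoiser with a FirstBlockCacheConfig, run inference, then disable_cache(). A hedged end-to-end sketch follows; the checkpoint, dtype, and threshold are illustrative and not taken from this diff:

```python
import torch

from diffusers import LTXPipeline
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")

# Reuse cached block outputs whenever the first transformer block's output changes by
# less than the threshold between steps; real runs use much smaller thresholds than the
# deliberately loose 0.8 used in the tests above.
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.05))
video = pipe(prompt="a red panda eating bamboo", num_inference_steps=30).frames[0]

pipe.transformer.disable_cache()  # restore uncached behavior
```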

View File

@@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import requests
from packaging.version import parse
from ..src.diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
# GitHub repository details
USER = "huggingface"
@@ -27,7 +30,11 @@ def fetch_all_branches(user, repo):
page = 1 # Start from first page
while True:
# Make a request to the GitHub API for the branches
response = requests.get(f"https://api.github.com/repos/{user}/{repo}/branches", params={"page": page})
response = requests.get(
f"https://api.github.com/repos/{user}/{repo}/branches",
params={"page": page},
timeout=DIFFUSERS_REQUEST_TIMEOUT,
)
# Check if the request was successful
if response.status_code == 200:
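
For context, here is a self-contained, hypothetical version of the pagination loop this hunk touches. It is not the script's actual code, but it shows the usual pattern of paging until the API returns an empty list, with the timeout applied on every request:

```python
import requests


def fetch_all_branches(user, repo, timeout=60):
    """Page through the GitHub branches API until an empty page is returned."""
    branches, page = [], 1
    while True:
        response = requests.get(
            f"https://api.github.com/repos/{user}/{repo}/branches",
            params={"page": page},
            timeout=timeout,
        )
        response.raise_for_status()
        page_branches = response.json()
        if not page_branches:  # an empty page means every branch has been seen
            break
        branches.extend(branch["name"] for branch in page_branches)
        page += 1
    return branches
```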

View File

@@ -17,6 +17,8 @@ import os
import requests
from ..src.diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
# Configuration
LIBRARY_NAME = "diffusers"
@@ -26,7 +28,7 @@ SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL")
def check_pypi_for_latest_release(library_name):
"""Check PyPI for the latest release of the library."""
response = requests.get(f"https://pypi.org/pypi/{library_name}/json")
response = requests.get(f"https://pypi.org/pypi/{library_name}/json", timeout=DIFFUSERS_REQUEST_TIMEOUT)
if response.status_code == 200:
data = response.json()
return data["info"]["version"]
@@ -38,7 +40,7 @@ def check_pypi_for_latest_release(library_name):
def get_github_release_info(github_repo):
"""Fetch the latest release info from GitHub."""
url = f"https://api.github.com/repos/{github_repo}/releases/latest"
response = requests.get(url)
response = requests.get(url, timeout=DIFFUSERS_REQUEST_TIMEOUT)
if response.status_code == 200:
data = response.json()