mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
28 Commits
update-sty
...
feature/gu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a1da7752e5 | ||
|
|
b30cf5d452 | ||
|
|
41afb6690c | ||
|
|
357f4f056b | ||
|
|
13e48492f0 | ||
|
|
94f2c48d58 | ||
|
|
aabf8ce20b | ||
|
|
f10775b1b5 | ||
|
|
53b6b9fcb6 | ||
|
|
46643564a3 | ||
|
|
6edb774b5e | ||
|
|
480510ada9 | ||
|
|
77324c40c4 | ||
|
|
05d74ef3e7 | ||
|
|
9997c223a8 | ||
|
|
d91d10737a | ||
|
|
5ac7f360b0 | ||
|
|
594e8d663f | ||
|
|
c76e1cc17e | ||
|
|
315e357a18 | ||
|
|
1f33ca276d | ||
|
|
41b0c473d2 | ||
|
|
0e232ac8c0 | ||
|
|
2557238b4d | ||
|
|
d71fe55895 | ||
|
|
7ab424a15a | ||
|
|
dd69b41834 | ||
|
|
406b1062f8 |
34
.github/workflows/pr_style_bot.yml
vendored
34
.github/workflows/pr_style_bot.yml
vendored
@@ -13,39 +13,5 @@ jobs:
|
||||
uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main
|
||||
with:
|
||||
python_quality_dependencies: "[quality]"
|
||||
pre_commit_script_name: "Download and Compare files from the main branch"
|
||||
pre_commit_script: |
|
||||
echo "Downloading the files from the main branch"
|
||||
|
||||
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
|
||||
curl -o main_setup.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/setup.py
|
||||
curl -o main_check_doc_toc.py https://raw.githubusercontent.com/huggingface/diffusers/refs/heads/main/utils/check_doc_toc.py
|
||||
|
||||
echo "Compare the files and raise error if needed"
|
||||
|
||||
diff_failed=0
|
||||
if ! diff -q main_Makefile Makefile; then
|
||||
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
|
||||
if ! diff -q main_setup.py setup.py; then
|
||||
echo "Error: The setup.py has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
|
||||
if ! diff -q main_check_doc_toc.py utils/check_doc_toc.py; then
|
||||
echo "Error: The utils/check_doc_toc.py has changed. Please ensure it matches the main branch."
|
||||
diff_failed=1
|
||||
fi
|
||||
|
||||
if [ $diff_failed -eq 1 ]; then
|
||||
echo "❌ Error happened as we detected changes in the files that should not be changed ❌"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "No changes in the files. Proceeding..."
|
||||
rm -rf main_Makefile main_setup.py main_check_doc_toc.py
|
||||
style_command: "make style && make quality"
|
||||
secrets:
|
||||
bot_token: ${{ secrets.GITHUB_TOKEN }}
|
||||
@@ -11,33 +11,6 @@ specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# Caching methods
|
||||
|
||||
## Pyramid Attention Broadcast
|
||||
|
||||
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
|
||||
|
||||
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
|
||||
|
||||
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
|
||||
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
|
||||
# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
|
||||
# poorer quality of generated videos.
|
||||
config = PyramidAttentionBroadcastConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(100, 800),
|
||||
current_timestep_callback=lambda: pipe.current_timestep,
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## Faster Cache
|
||||
|
||||
[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.
|
||||
@@ -65,18 +38,68 @@ config = FasterCacheConfig(
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## First Block Cache
|
||||
|
||||
[First Block Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) is a method that builds upon the ideas of [TeaCache](https://huggingface.co/papers/2411.19108) to speed up inference in diffusion transformers. The generation quality is superior with greatly reduced inference time. This method always computes the output of the first transformer block and computes the differences between past and current outputs of the first transformer block. If the difference is smaller than a predefined threshold, the computation of remaining transformer blocks is skipped, and otherwise the computation is performed as usual.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, FirstBlockCacheConfig
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Increasing the threshold may lead to faster inference speeds, but may also lead to poorer quality of generated videos.
|
||||
# Smaller values between 0.02-2.0 are recommended based on the model being used. The default value is 0.05.
|
||||
config = FirstBlockCacheConfig(threshold=0.07)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## Pyramid Attention Broadcast
|
||||
|
||||
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
|
||||
|
||||
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. The attention states are not very different between successive inference steps. The most prominent difference is in the spatial attention blocks, not as much in the temporal attention blocks, and finally the least in the cross attention blocks. Therefore, many cross attention computation blocks can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
|
||||
|
||||
Enable PAB with [`~PyramidAttentionBroadcastConfig`] on any pipeline. For some benchmarks, refer to [this](https://github.com/huggingface/diffusers/pull/9562) pull request.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Increasing the value of `spatial_attention_timestep_skip_range[0]` or decreasing the value of
|
||||
# `spatial_attention_timestep_skip_range[1]` will decrease the interval in which pyramid attention
|
||||
# broadcast is active, leader to slower inference speeds. However, large intervals can lead to
|
||||
# poorer quality of generated videos.
|
||||
config = PyramidAttentionBroadcastConfig(
|
||||
spatial_attention_block_skip_range=2,
|
||||
spatial_attention_timestep_skip_range=(100, 800),
|
||||
current_timestep_callback=lambda: pipe.current_timestep,
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
### CacheMixin
|
||||
|
||||
[[autodoc]] CacheMixin
|
||||
|
||||
### PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] apply_pyramid_attention_broadcast
|
||||
|
||||
### FasterCacheConfig
|
||||
|
||||
[[autodoc]] FasterCacheConfig
|
||||
|
||||
[[autodoc]] apply_faster_cache
|
||||
|
||||
### FirstBlockCacheConfig
|
||||
|
||||
[[autodoc]] FirstBlockCacheConfig
|
||||
|
||||
[[autodoc]] apply_first_block_cache
|
||||
|
||||
### PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] PyramidAttentionBroadcastConfig
|
||||
|
||||
[[autodoc]] apply_pyramid_attention_broadcast
|
||||
|
||||
@@ -32,6 +32,7 @@ Available models:
|
||||
|:-------------:|:-----------------:|
|
||||
| [`LTX Video 0.9.0`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.safetensors) | `torch.bfloat16` |
|
||||
| [`LTX Video 0.9.1`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.1.safetensors) | `torch.bfloat16` |
|
||||
| [`LTX Video 0.9.5`](https://huggingface.co/Lightricks/LTX-Video/blob/main/ltx-video-2b-v0.9.5.safetensors) | `torch.bfloat16` |
|
||||
|
||||
Note: The recommended dtype is for the transformer component. The VAE and text encoders can be either `torch.float32`, `torch.bfloat16` or `torch.float16` but the recommended dtype is `torch.bfloat16` as used in the original repository.
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
|
||||
|
||||
| Example | Description | Code Example | Colab | Author |
|
||||
|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:|
|
||||
|Spatiotemporal Skip Guidance (STG)|[Spatiotemporal Skip Guidance for Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664) (CVPR 2025) enhances video diffusion models by generating a weaker model through layer skipping and using it as guidance, improving fidelity in models like HunyuanVideo, LTXVideo, and Mochi.|[Spatiotemporal Skip Guidance](#spatiotemporal-skip-guidance)|-|[Junha Hyung](https://junhahyung.github.io/), [Kinam Kim](https://kinam0252.github.io/)|
|
||||
|Spatiotemporal Skip Guidance (STG)|[Spatiotemporal Skip Guidance for Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664) (CVPR 2025) enhances video diffusion models by generating a weaker model through layer skipping and using it as guidance, improving fidelity in models like HunyuanVideo, LTXVideo, and Mochi.|[Spatiotemporal Skip Guidance](#spatiotemporal-skip-guidance)|-|[Junha Hyung](https://junhahyung.github.io/), [Kinam Kim](https://kinam0252.github.io/), and [Ednaordinary](https://github.com/Ednaordinary)|
|
||||
|Adaptive Mask Inpainting|Adaptive Mask Inpainting algorithm from [Beyond the Contact: Discovering Comprehensive Affordance for 3D Objects from Pre-trained 2D Diffusion Models](https://github.com/snuvclab/coma) (ECCV '24, Oral) provides a way to insert human inside the scene image without altering the background, by inpainting with adapting mask.|[Adaptive Mask Inpainting](#adaptive-mask-inpainting)|-|[Hyeonwoo Kim](https://sshowbiz.xyz),[Sookwan Han](https://jellyheadandrew.github.io)|
|
||||
|Flux with CFG|[Flux with CFG](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md) provides an implementation of using CFG in [Flux](https://blackforestlabs.ai/announcing-black-forest-labs/).|[Flux with CFG](#flux-with-cfg)|[Notebook](https://github.com/huggingface/notebooks/blob/main/diffusers/flux_with_cfg.ipynb)|[Linoy Tsaban](https://github.com/linoytsaban), [Apolinário](https://github.com/apolinario), and [Sayak Paul](https://github.com/sayakpaul)|
|
||||
|Differential Diffusion|[Differential Diffusion](https://github.com/exx8/differential-diffusion) modifies an image according to a text prompt, and according to a map that specifies the amount of change in each region.|[Differential Diffusion](#differential-diffusion)|[](https://huggingface.co/spaces/exx8/differential-diffusion) [](https://colab.research.google.com/github/exx8/differential-diffusion/blob/main/examples/SD2.ipynb)|[Eran Levin](https://github.com/exx8) and [Ohad Fried](https://www.ohadf.com/)|
|
||||
@@ -124,7 +124,6 @@ pipe = pipe.to("cuda")
|
||||
#--------Option--------#
|
||||
prompt = "A close-up of a beautiful woman's face with colored powder exploding around her, creating an abstract splash of vibrant hues, realistic style."
|
||||
stg_applied_layers_idx = [34]
|
||||
stg_mode = "STG"
|
||||
stg_scale = 1.0 # 0.0 for CFG
|
||||
#----------------------#
|
||||
|
||||
|
||||
661
examples/community/pipeline_stg_wan.py
Normal file
661
examples/community/pipeline_stg_wan.py
Normal file
@@ -0,0 +1,661 @@
|
||||
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import ftfy
|
||||
import regex as re
|
||||
import torch
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from diffusers.loaders import WanLoraLoaderMixin
|
||||
from diffusers.models import AutoencoderKLWan, WanTransformer3DModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.video_processor import VideoProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import export_to_video
|
||||
>>> from diffusers import AutoencoderKLWan
|
||||
>>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
|
||||
>>> from examples.community.pipeline_stg_wan import WanSTGPipeline
|
||||
|
||||
>>> # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
|
||||
>>> model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
|
||||
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
>>> pipe = WanSTGPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
|
||||
>>> flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P
|
||||
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
|
||||
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
||||
|
||||
>>> # Configure STG mode options
|
||||
>>> stg_applied_layers_idx = [8] # Layer indices from 0 to 39 for 14b or 0 to 29 for 1.3b
|
||||
>>> stg_scale = 1.0 # Set 0.0 for CFG
|
||||
|
||||
>>> output = pipe(
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... height=720,
|
||||
... width=1280,
|
||||
... num_frames=81,
|
||||
... guidance_scale=5.0,
|
||||
... stg_applied_layers_idx=stg_applied_layers_idx,
|
||||
... stg_scale=stg_scale,
|
||||
... ).frames[0]
|
||||
>>> export_to_video(output, "output.mp4", fps=16)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def prompt_clean(text):
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
return text
|
||||
|
||||
|
||||
def forward_with_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
rotary_emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward_without_stg(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
rotary_emb: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table + temb.float()
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# 1. Self-attention
|
||||
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
|
||||
attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
|
||||
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
|
||||
|
||||
# 2. Cross-attention
|
||||
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
|
||||
attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
|
||||
ff_output = self.ffn(norm_hidden_states)
|
||||
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class WanSTGPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-video generation using Wan.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
|
||||
Args:
|
||||
tokenizer ([`T5Tokenizer`]):
|
||||
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
|
||||
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
||||
transformer ([`WanTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
scheduler ([`UniPCMultistepScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: AutoTokenizer,
|
||||
text_encoder: UMT5EncoderModel,
|
||||
transformer: WanTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_videos_per_prompt: int = 1,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||
)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_videos_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 226,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif negative_prompt is not None and (
|
||||
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
|
||||
):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: int = 16,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_latent_frames,
|
||||
int(height) // self.vae_scale_factor_spatial,
|
||||
int(width) // self.vae_scale_factor_spatial,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return self._stg_scale > 0.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: Union[str, List[str]] = None,
|
||||
height: int = 480,
|
||||
width: int = 832,
|
||||
num_frames: int = 81,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 5.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "np",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
stg_applied_layers_idx: Optional[List[int]] = [3, 8, 16],
|
||||
stg_scale: Optional[float] = 0.0,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, defaults to `480`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, defaults to `832`):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, defaults to `81`):
|
||||
The number of frames in the generated video.
|
||||
num_inference_steps (`int`, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `5.0`):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
||||
provided, text embeddings are generated from the `prompt` input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
||||
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
||||
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
||||
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
||||
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
|
||||
The dtype to use for the torch.amp.autocast.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~WanPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
||||
the first element is a list with the generated images and the second element is a list of `bool`s
|
||||
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._stg_scale = stg_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
transformer_dtype = self.transformer.dtype
|
||||
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
||||
if negative_prompt_embeds is not None:
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_videos_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
num_frames,
|
||||
torch.float32,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for idx, block in enumerate(self.transformer.blocks):
|
||||
block.forward = types.MethodType(forward_without_stg, block)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
if self.do_spatio_temporal_guidance:
|
||||
for idx, block in enumerate(self.transformer.blocks):
|
||||
if idx in stg_applied_layers_idx:
|
||||
block.forward = types.MethodType(forward_with_stg, block)
|
||||
noise_perturb = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = (
|
||||
noise_uncond
|
||||
+ guidance_scale * (noise_pred - noise_uncond)
|
||||
+ self._stg_scale * (noise_pred - noise_perturb)
|
||||
)
|
||||
else:
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return WanPipelineOutput(frames=video)
|
||||
@@ -49,6 +49,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2P
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import check_min_version, deprecate, is_wandb_available
|
||||
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
@@ -418,7 +419,7 @@ def convert_to_np(image, resolution):
|
||||
|
||||
|
||||
def download_image(url):
|
||||
image = PIL.Image.open(requests.get(url, stream=True).raw)
|
||||
image = PIL.Image.open(requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
|
||||
image = PIL.ImageOps.exif_transpose(image)
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
|
||||
@@ -54,6 +54,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2P
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel, cast_training_params
|
||||
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, deprecate, is_wandb_available
|
||||
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
@@ -475,7 +476,7 @@ def convert_to_np(image, resolution):
|
||||
|
||||
|
||||
def download_image(url):
|
||||
image = PIL.Image.open(requests.get(url, stream=True).raw)
|
||||
image = PIL.Image.open(requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
|
||||
image = PIL.ImageOps.exif_transpose(image)
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
|
||||
@@ -59,6 +59,7 @@ from diffusers.schedulers import (
|
||||
UnCLIPScheduler,
|
||||
)
|
||||
from diffusers.utils import is_accelerate_available, logging
|
||||
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
@@ -1435,7 +1436,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
|
||||
|
||||
if config_url is not None:
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
original_config_file = BytesIO(requests.get(config_url, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
|
||||
else:
|
||||
with open(original_config_file, "r") as f:
|
||||
original_config_file = f.read()
|
||||
|
||||
@@ -11,6 +11,7 @@ from diffusion import sampling
|
||||
from torch import nn
|
||||
|
||||
from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
|
||||
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
|
||||
|
||||
MODELS_MAP = {
|
||||
@@ -74,7 +75,7 @@ class DiffusionUncond(nn.Module):
|
||||
|
||||
def download(model_name):
|
||||
url = MODELS_MAP[model_name]["url"]
|
||||
r = requests.get(url, stream=True)
|
||||
r = requests.get(url, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT)
|
||||
|
||||
local_filename = f"./{model_name}.ckpt"
|
||||
with open(local_filename, "wb") as fp:
|
||||
|
||||
@@ -13,6 +13,7 @@ from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
|
||||
renew_vae_attention_paths,
|
||||
renew_vae_resnet_paths,
|
||||
)
|
||||
from diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
|
||||
|
||||
def custom_convert_ldm_vae_checkpoint(checkpoint, config):
|
||||
@@ -122,7 +123,8 @@ def vae_pt_to_vae_diffuser(
|
||||
):
|
||||
# Only support V1
|
||||
r = requests.get(
|
||||
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
|
||||
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml",
|
||||
timeout=DIFFUSERS_REQUEST_TIMEOUT,
|
||||
)
|
||||
io_obj = io.BytesIO(r.content)
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ from .utils import (
|
||||
|
||||
_import_structure = {
|
||||
"configuration_utils": ["ConfigMixin"],
|
||||
"guiders": [],
|
||||
"hooks": [],
|
||||
"loaders": ["FromOriginalModelMixin"],
|
||||
"models": [],
|
||||
@@ -129,12 +130,25 @@ except OptionalDependencyNotAvailable:
|
||||
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
|
||||
|
||||
else:
|
||||
_import_structure["guiders"].extend(
|
||||
[
|
||||
"AdaptiveProjectedGuidance",
|
||||
"ClassifierFreeGuidance",
|
||||
"ClassifierFreeZeroStarGuidance",
|
||||
"PerturbedAttentionGuidance",
|
||||
"SkipLayerGuidance",
|
||||
]
|
||||
)
|
||||
_import_structure["hooks"].extend(
|
||||
[
|
||||
"FasterCacheConfig",
|
||||
"FirstBlockCacheConfig",
|
||||
"HookRegistry",
|
||||
"LayerSkipConfig",
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"apply_faster_cache",
|
||||
"apply_first_block_cache",
|
||||
"apply_layer_skip",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
]
|
||||
)
|
||||
@@ -708,11 +722,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .guiders import (
|
||||
AdaptiveProjectedGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
PerturbedAttentionGuidance,
|
||||
SkipLayerGuidance,
|
||||
)
|
||||
from .hooks import (
|
||||
FasterCacheConfig,
|
||||
FirstBlockCacheConfig,
|
||||
HookRegistry,
|
||||
LayerSkipConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_layer_skip,
|
||||
apply_pyramid_attention_broadcast,
|
||||
)
|
||||
from .models import (
|
||||
|
||||
24
src/diffusers/guiders/__init__.py
Normal file
24
src/diffusers/guiders/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
|
||||
from .classifier_free_guidance import ClassifierFreeGuidance
|
||||
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
||||
from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning
|
||||
from .perturbed_attention_guidance import PerturbedAttentionGuidance
|
||||
from .skip_layer_guidance import SkipLayerGuidance
|
||||
145
src/diffusers/guiders/adaptive_projected_guidance.py
Normal file
145
src/diffusers/guiders/adaptive_projected_guidance.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .guider_utils import GuidanceMixin, rescale_noise_cfg
|
||||
|
||||
|
||||
class AdaptiveProjectedGuidance(GuidanceMixin):
|
||||
"""
|
||||
Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
|
||||
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
|
||||
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
adaptive_projected_guidance_momentum: Optional[float] = None,
|
||||
adaptive_projected_guidance_rescale: float = 15.0,
|
||||
eta: float = 1.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
||||
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
|
||||
self.eta = eta
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.momentum_buffer = None
|
||||
|
||||
def prepare_inputs(self, *args):
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
return super().prepare_inputs(*args)
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
pred = normalized_guidance(
|
||||
pred_cond,
|
||||
pred_uncond,
|
||||
self.guidance_scale,
|
||||
self.momentum_buffer,
|
||||
self.eta,
|
||||
self.adaptive_projected_guidance_rescale,
|
||||
self.use_original_formulation,
|
||||
)
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if self.use_original_formulation:
|
||||
return not math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
return not math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
|
||||
class MomentumBuffer:
|
||||
def __init__(self, momentum: float):
|
||||
self.momentum = momentum
|
||||
self.running_average = 0
|
||||
|
||||
def update(self, update_value: torch.Tensor):
|
||||
new_average = self.momentum * self.running_average
|
||||
self.running_average = update_value + new_average
|
||||
|
||||
|
||||
def normalized_guidance(
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: torch.Tensor,
|
||||
guidance_scale: float,
|
||||
momentum_buffer: Optional[MomentumBuffer] = None,
|
||||
eta: float = 1.0,
|
||||
norm_threshold: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
diff = pred_cond - pred_uncond
|
||||
dim = [-i for i in range(1, len(diff.shape))]
|
||||
if momentum_buffer is not None:
|
||||
momentum_buffer.update(diff)
|
||||
diff = momentum_buffer.running_average
|
||||
if norm_threshold > 0:
|
||||
ones = torch.ones_like(diff)
|
||||
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
|
||||
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
||||
diff = diff * scale_factor
|
||||
v0, v1 = diff.double(), pred_cond.double()
|
||||
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
||||
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
|
||||
normalized_update = diff_orthogonal + eta * diff_parallel
|
||||
pred = pred_cond if use_original_formulation else pred_uncond
|
||||
pred = pred + (guidance_scale - 1) * normalized_update
|
||||
return pred
|
||||
98
src/diffusers/guiders/classifier_free_guidance.py
Normal file
98
src/diffusers/guiders/classifier_free_guidance.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .guider_utils import GuidanceMixin, rescale_noise_cfg
|
||||
|
||||
|
||||
class ClassifierFreeGuidance(GuidanceMixin):
|
||||
"""
|
||||
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
|
||||
|
||||
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
|
||||
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
|
||||
inference. This allows the model to tradeoff between generation quality and sample diversity.
|
||||
|
||||
The original paper proposes scaling and shifting the conditional distribution based on the difference between
|
||||
conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
|
||||
|
||||
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
|
||||
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
|
||||
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
|
||||
|
||||
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
|
||||
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
|
||||
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
|
||||
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
|
||||
|
||||
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
|
||||
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
def __init__(
|
||||
self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if self.use_original_formulation:
|
||||
return not math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
return not math.isclose(self.guidance_scale, 1.0)
|
||||
110
src/diffusers/guiders/classifier_free_zero_star_guidance.py
Normal file
110
src/diffusers/guiders/classifier_free_zero_star_guidance.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .guider_utils import GuidanceMixin, rescale_noise_cfg
|
||||
|
||||
|
||||
class ClassifierFreeZeroStarGuidance(GuidanceMixin):
|
||||
"""
|
||||
Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
|
||||
|
||||
This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
|
||||
guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
|
||||
process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
|
||||
quality of generated images.
|
||||
|
||||
The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
zero_init_steps (`int`, defaults to `1`):
|
||||
The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
zero_init_steps: int = 1,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.zero_init_steps = zero_init_steps
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if self._step < self.zero_init_steps:
|
||||
pred = torch.zeros_like(pred_cond)
|
||||
elif self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
pred_cond_flat = pred_cond.flatten(1)
|
||||
pred_uncond_flat = pred_uncond.flatten(1)
|
||||
alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
|
||||
alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
|
||||
pred_uncond = pred_uncond * alpha
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if self.use_original_formulation:
|
||||
return not math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
return not math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
|
||||
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
|
||||
cond = cond.float()
|
||||
uncond = uncond.float()
|
||||
dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
|
||||
squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
|
||||
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
||||
scale = dot_product / squared_norm
|
||||
return scale.type_as(cond)
|
||||
213
src/diffusers/guiders/guider_utils.py
Normal file
213
src/diffusers/guiders/guider_utils.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import deprecate, get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models.attention_processor import AttentionProcessor
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class GuidanceMixin:
|
||||
r"""Base mixin class providing the skeleton for implementing guidance techniques."""
|
||||
|
||||
_input_predictions = None
|
||||
|
||||
def __init__(self):
|
||||
self._step: int = None
|
||||
self._num_inference_steps: int = None
|
||||
self._timestep: torch.LongTensor = None
|
||||
self._preds: Dict[str, torch.Tensor] = {}
|
||||
self._num_outputs_prepared: int = 0
|
||||
|
||||
if self._input_predictions is None or not isinstance(self._input_predictions, list):
|
||||
raise ValueError(
|
||||
"`_input_predictions` must be a list of required prediction names for the guidance technique."
|
||||
)
|
||||
|
||||
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
|
||||
self._step = step
|
||||
self._num_inference_steps = num_inference_steps
|
||||
self._timestep = timestep
|
||||
self._preds = {}
|
||||
self._num_outputs_prepared = 0
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
|
||||
num_conditions = self.num_conditions
|
||||
list_of_inputs = []
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
list_of_inputs.append([arg] * num_conditions)
|
||||
elif isinstance(arg, (tuple, list)):
|
||||
if len(arg) != 2:
|
||||
raise ValueError(
|
||||
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
|
||||
f"with the first element being the conditional input and the second element being the unconditional input or None."
|
||||
)
|
||||
if arg[1] is None:
|
||||
# Only conditioning inputs for all batches
|
||||
list_of_inputs.append([arg[0]] * num_conditions)
|
||||
else:
|
||||
# Alternating conditional and unconditional inputs as batches
|
||||
inputs = [arg[i % 2] for i in range(num_conditions)]
|
||||
list_of_inputs.append(inputs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
|
||||
)
|
||||
return tuple(list_of_inputs)
|
||||
|
||||
def prepare_outputs(self, pred: torch.Tensor) -> None:
|
||||
self._num_outputs_prepared += 1
|
||||
if self._num_outputs_prepared > self.num_conditions:
|
||||
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
|
||||
key = self._input_predictions[self._num_outputs_prepared - 1]
|
||||
self._preds[key] = pred
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def __call__(self, **kwargs) -> Any:
|
||||
if len(kwargs) != self.num_conditions:
|
||||
raise ValueError(
|
||||
f"Expected {self.num_conditions} arguments, but got {len(kwargs)}. Please provide the correct number of arguments."
|
||||
)
|
||||
return self.forward(**kwargs)
|
||||
|
||||
def forward(self, *args, **kwargs) -> Any:
|
||||
raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.")
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.")
|
||||
|
||||
@property
|
||||
def outputs(self) -> Dict[str, torch.Tensor]:
|
||||
return self._preds
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
||||
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
|
||||
Args:
|
||||
noise_cfg (`torch.Tensor`):
|
||||
The predicted noise tensor for the guided diffusion process.
|
||||
noise_pred_text (`torch.Tensor`):
|
||||
The predicted noise tensor for the text-guided diffusion process.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
A rescale factor applied to the noise predictions.
|
||||
|
||||
Returns:
|
||||
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
|
||||
"""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
# rescale the results from guidance (fixes overexposure)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
|
||||
|
||||
def _replace_attention_processors(
|
||||
module: torch.nn.Module,
|
||||
pag_applied_layers: Optional[Union[str, List[str]]] = None,
|
||||
skip_context_attention: bool = False,
|
||||
processors: Optional[List[Tuple[torch.nn.Module, "AttentionProcessor"]]] = None,
|
||||
metadata_name: Optional[str] = None,
|
||||
) -> Optional[List[Tuple[torch.nn.Module, "AttentionProcessor"]]]:
|
||||
if processors is not None and metadata_name is not None:
|
||||
raise ValueError("Cannot pass both `processors` and `metadata_name` at the same time.")
|
||||
if metadata_name is not None:
|
||||
if isinstance(pag_applied_layers, str):
|
||||
pag_applied_layers = [pag_applied_layers]
|
||||
return _replace_layers_with_guidance_processors(
|
||||
module, pag_applied_layers, skip_context_attention, metadata_name
|
||||
)
|
||||
if processors is not None:
|
||||
_replace_layers_with_existing_processors(processors)
|
||||
|
||||
|
||||
def _replace_layers_with_guidance_processors(
|
||||
module: torch.nn.Module,
|
||||
pag_applied_layers: List[str],
|
||||
skip_context_attention: bool,
|
||||
metadata_name: str,
|
||||
) -> List[Tuple[torch.nn.Module, "AttentionProcessor"]]:
|
||||
from ..hooks._common import _ATTENTION_CLASSES
|
||||
from ..hooks._helpers import GuidanceMetadataRegistry
|
||||
|
||||
processors = []
|
||||
for name, submodule in module.named_modules():
|
||||
if (
|
||||
(not isinstance(submodule, _ATTENTION_CLASSES))
|
||||
or (getattr(submodule, "processor", None) is None)
|
||||
or not (
|
||||
any(
|
||||
re.search(pag_layer, name) is not None and not _is_fake_integral_match(pag_layer, name)
|
||||
for pag_layer in pag_applied_layers
|
||||
)
|
||||
)
|
||||
):
|
||||
continue
|
||||
old_attention_processor = submodule.processor
|
||||
metadata = GuidanceMetadataRegistry.get(old_attention_processor.__class__)
|
||||
new_attention_processor_cls = getattr(metadata, metadata_name)
|
||||
new_attention_processor = new_attention_processor_cls()
|
||||
# !!! dunder methods cannot be replaced on instances !!!
|
||||
# if "skip_context_attention" in inspect.signature(new_attention_processor.__call__).parameters:
|
||||
# new_attention_processor.__call__ = partial(
|
||||
# new_attention_processor.__call__, skip_context_attention=skip_context_attention
|
||||
# )
|
||||
submodule.processor = new_attention_processor
|
||||
processors.append((submodule, old_attention_processor))
|
||||
return processors
|
||||
|
||||
|
||||
def _replace_layers_with_existing_processors(processors: List[Tuple[torch.nn.Module, "AttentionProcessor"]]) -> None:
|
||||
for module, proc in processors:
|
||||
module.processor = proc
|
||||
|
||||
|
||||
def _is_fake_integral_match(layer_id, name):
|
||||
layer_id = layer_id.split(".")[-1]
|
||||
name = name.split(".")[-1]
|
||||
return layer_id.isnumeric() and name.isnumeric() and layer_id == name
|
||||
|
||||
|
||||
def _raise_guidance_deprecation_warning(
|
||||
*,
|
||||
guidance_scale: Optional[float] = None,
|
||||
guidance_rescale: Optional[float] = None,
|
||||
) -> None:
|
||||
if guidance_scale is not None:
|
||||
msg = "The `guidance_scale` argument is deprecated and will be removed in version 1.0.0. Please pass a `GuidanceMixin` object for the `guidance` argument instead."
|
||||
deprecate("guidance_scale", "1.0.0", msg, standard_warn=False)
|
||||
if guidance_rescale is not None:
|
||||
msg = "The `guidance_rescale` argument is deprecated and will be removed in version 1.0.0. Please pass a `GuidanceMixin` object for the `guidance` argument instead."
|
||||
deprecate("guidance_rescale", "1.0.0", msg, standard_warn=False)
|
||||
180
src/diffusers/guiders/perturbed_attention_guidance.py
Normal file
180
src/diffusers/guiders/perturbed_attention_guidance.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .guider_utils import GuidanceMixin, _replace_attention_processors, rescale_noise_cfg
|
||||
|
||||
|
||||
class PerturbedAttentionGuidance(GuidanceMixin):
|
||||
"""
|
||||
Perturbed Attention Guidance (PAB): https://huggingface.co/papers/2403.17377
|
||||
|
||||
Args:
|
||||
pag_applied_layers (`str` or `List[str]`):
|
||||
The name of the attention layers where Perturbed Attention Guidance is applied. This can be a single layer
|
||||
name or a list of layer names. The names should either be FQNs (fully qualified names) to each attention
|
||||
layer or a regex pattern that matches the FQNs of the attention layers. For example, if you want to apply
|
||||
PAG to transformer blocks 10 and 20, you can set this to `["transformer_blocks.10",
|
||||
"transformer_blocks.20"]`, or `"transformer_blocks.(10|20)"`.
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
pag_scale (`float`, defaults to `3.0`):
|
||||
The scale parameter for perturbed attention guidance.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond", "pred_perturbed"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pag_applied_layers: Union[str, List[str]],
|
||||
guidance_scale: float = 7.5,
|
||||
pag_scale: float = 3.0,
|
||||
skip_context_attention: bool = False,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.pag_applied_layers = pag_applied_layers
|
||||
self.guidance_scale = guidance_scale
|
||||
self.pag_scale = pag_scale
|
||||
self.skip_context_attention = skip_context_attention
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
self._is_pag_batch = False
|
||||
self._original_processors = None
|
||||
self._denoiser = None
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module):
|
||||
self._denoiser = denoiser
|
||||
|
||||
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
|
||||
num_conditions = self.num_conditions
|
||||
list_of_inputs = []
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
list_of_inputs.append([arg] * num_conditions)
|
||||
elif isinstance(arg, (tuple, list)):
|
||||
if len(arg) != 2:
|
||||
raise ValueError(
|
||||
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
|
||||
f"with the first element being the conditional input and the second element being the unconditional input or None."
|
||||
)
|
||||
if arg[1] is None:
|
||||
# Only conditioning inputs for all batches
|
||||
list_of_inputs.append([arg[0]] * num_conditions)
|
||||
else:
|
||||
list_of_inputs.append([arg[0], arg[1], arg[0]])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
|
||||
)
|
||||
return tuple(list_of_inputs)
|
||||
|
||||
def prepare_outputs(self, pred: torch.Tensor) -> None:
|
||||
self._num_outputs_prepared += 1
|
||||
if self._num_outputs_prepared > self.num_conditions:
|
||||
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
|
||||
key = self._input_predictions[self._num_outputs_prepared - 1]
|
||||
if not self._is_cfg_enabled() and self._is_pag_enabled():
|
||||
# If we're predicting pred_cond and pred_perturbed only, we need to set the key to pred_perturbed
|
||||
# to avoid writing into pred_uncond which is not used
|
||||
if self._num_outputs_prepared == 2:
|
||||
key = "pred_perturbed"
|
||||
self._preds[key] = pred
|
||||
|
||||
# Restore the original attention processors if previously replaced
|
||||
if self._is_pag_batch:
|
||||
_replace_attention_processors(self._denoiser, processors=self._original_processors)
|
||||
self._is_pag_batch = False
|
||||
self._original_processors = None
|
||||
|
||||
# Prepare denoiser for perturbed attention prediction if needed
|
||||
if self._is_pag_enabled():
|
||||
should_register_pag = (self._is_cfg_enabled() and self._num_outputs_prepared == 2) or (
|
||||
not self._is_cfg_enabled() and self._num_outputs_prepared == 1
|
||||
)
|
||||
if should_register_pag:
|
||||
self._is_pag_batch = True
|
||||
self._original_processors = _replace_attention_processors(
|
||||
self._denoiser,
|
||||
self.pag_applied_layers,
|
||||
skip_context_attention=self.skip_context_attention,
|
||||
metadata_name="perturbed_attention_guidance_processor_cls",
|
||||
)
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module):
|
||||
self._denoiser = None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: Optional[torch.Tensor] = None,
|
||||
pred_perturbed: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_cfg_enabled() and not self._is_pag_enabled():
|
||||
pred = pred_cond
|
||||
elif not self._is_cfg_enabled():
|
||||
shift = pred_cond - pred_perturbed
|
||||
pred = pred_cond + self.pag_scale * shift
|
||||
elif not self._is_pag_enabled():
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
shift_perturbed = pred_cond - pred_perturbed
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift + self.pag_scale * shift_perturbed
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
if self._is_pag_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if self.use_original_formulation:
|
||||
return not math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
return not math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
def _is_pag_enabled(self) -> bool:
|
||||
is_zero = math.isclose(self.pag_scale, 0.0)
|
||||
return not is_zero
|
||||
235
src/diffusers/guiders/skip_layer_guidance.py
Normal file
235
src/diffusers/guiders/skip_layer_guidance.py
Normal file
@@ -0,0 +1,235 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..hooks import HookRegistry, LayerSkipConfig
|
||||
from ..hooks.layer_skip import _apply_layer_skip_hook
|
||||
from .guider_utils import GuidanceMixin, rescale_noise_cfg
|
||||
|
||||
|
||||
class SkipLayerGuidance(GuidanceMixin):
|
||||
"""
|
||||
Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5 Spatio-Temporal Guidance (STG):
|
||||
https://huggingface.co/papers/2411.18664
|
||||
|
||||
SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by
|
||||
skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
|
||||
batch of data, apart from the conditional and unconditional batches already used in CFG
|
||||
([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
|
||||
based on the difference between conditional without skipping and conditional with skipping predictions.
|
||||
|
||||
The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
|
||||
worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
|
||||
version of the model for the conditional prediction).
|
||||
|
||||
STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
|
||||
generation quality in video diffusion models.
|
||||
|
||||
Additional reading:
|
||||
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
|
||||
|
||||
The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
|
||||
defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
skip_layer_guidance_scale (`float`, defaults to `2.8`):
|
||||
The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
|
||||
values, but it may also lead to overexposure and saturation.
|
||||
skip_layer_guidance_start (`float`, defaults to `0.01`):
|
||||
The fraction of the total number of denoising steps after which skip layer guidance starts.
|
||||
skip_layer_guidance_stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which skip layer guidance stops.
|
||||
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
|
||||
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
|
||||
3.5 Medium.
|
||||
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
|
||||
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
|
||||
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
skip_layer_guidance_scale: float = 2.8,
|
||||
skip_layer_guidance_start: float = 0.01,
|
||||
skip_layer_guidance_stop: float = 0.2,
|
||||
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.skip_layer_guidance_scale = skip_layer_guidance_scale
|
||||
self.skip_layer_guidance_start = skip_layer_guidance_start
|
||||
self.skip_layer_guidance_stop = skip_layer_guidance_stop
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
if not (0.0 <= skip_layer_guidance_start < 1.0):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
|
||||
)
|
||||
if not (0.0 < skip_layer_guidance_stop <= 1.0):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
|
||||
)
|
||||
|
||||
if skip_layer_guidance_layers is None and skip_layer_config is None:
|
||||
raise ValueError(
|
||||
"Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
|
||||
)
|
||||
if skip_layer_guidance_layers is not None and skip_layer_config is not None:
|
||||
raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")
|
||||
|
||||
if skip_layer_guidance_layers is not None:
|
||||
if isinstance(skip_layer_guidance_layers, int):
|
||||
skip_layer_guidance_layers = [skip_layer_guidance_layers]
|
||||
if not isinstance(skip_layer_guidance_layers, list):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
|
||||
)
|
||||
skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
|
||||
|
||||
if isinstance(skip_layer_config, LayerSkipConfig):
|
||||
skip_layer_config = [skip_layer_config]
|
||||
|
||||
if not isinstance(skip_layer_config, list):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
|
||||
)
|
||||
|
||||
self.skip_layer_config = skip_layer_config
|
||||
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module):
|
||||
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
||||
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
||||
|
||||
# Register the hooks for layer skipping if the step is within the specified range
|
||||
if skip_start_step < self._step < skip_stop_step:
|
||||
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
|
||||
_apply_layer_skip_hook(denoiser, config, name=name)
|
||||
|
||||
def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
|
||||
num_conditions = self.num_conditions
|
||||
list_of_inputs = []
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
list_of_inputs.append([arg] * num_conditions)
|
||||
elif isinstance(arg, (tuple, list)):
|
||||
if len(arg) != 2:
|
||||
raise ValueError(
|
||||
f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
|
||||
f"with the first element being the conditional input and the second element being the unconditional input or None."
|
||||
)
|
||||
if arg[1] is None:
|
||||
# Only conditioning inputs for all batches
|
||||
list_of_inputs.append([arg[0]] * num_conditions)
|
||||
else:
|
||||
list_of_inputs.append([arg[0], arg[1], arg[0]])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
|
||||
)
|
||||
return tuple(list_of_inputs)
|
||||
|
||||
def prepare_outputs(self, pred: torch.Tensor) -> None:
|
||||
self._num_outputs_prepared += 1
|
||||
if self._num_outputs_prepared > self.num_conditions:
|
||||
raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
|
||||
key = self._input_predictions[self._num_outputs_prepared - 1]
|
||||
if not self._is_cfg_enabled() and self._is_slg_enabled():
|
||||
# If we're predicting pred_cond and pred_cond_skip only, we need to set the key to pred_cond_skip
|
||||
# to avoid writing into pred_uncond which is not used
|
||||
if self._num_outputs_prepared == 2:
|
||||
key = "pred_cond_skip"
|
||||
self._preds[key] = pred
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
# Remove the hooks after inference
|
||||
for hook_name in self._skip_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: Optional[torch.Tensor] = None,
|
||||
pred_cond_skip: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_cfg_enabled() and not self._is_slg_enabled():
|
||||
pred = pred_cond
|
||||
elif not self._is_cfg_enabled():
|
||||
shift = pred_cond - pred_cond_skip
|
||||
pred = pred_cond if self.use_original_formulation else pred_cond_skip
|
||||
pred = pred + self.skip_layer_guidance_scale * shift
|
||||
elif not self._is_slg_enabled():
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
shift_skip = pred_cond - pred_cond_skip
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
if self._is_slg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if self.use_original_formulation:
|
||||
return not math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
return not math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
def _is_slg_enabled(self) -> bool:
|
||||
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
||||
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step < self._step < skip_stop_step
|
||||
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
|
||||
return is_within_range and not is_zero
|
||||
@@ -1,9 +1,25 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .faster_cache import FasterCacheConfig, apply_faster_cache
|
||||
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
|
||||
from .group_offloading import apply_group_offloading
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layer_skip import LayerSkipConfig, apply_layer_skip
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
|
||||
32
src/diffusers/hooks/_common.py
Normal file
32
src/diffusers/hooks/_common.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..models.attention import FeedForward, LuminaFeedForward
|
||||
from ..models.attention_processor import Attention, MochiAttention
|
||||
|
||||
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
|
||||
|
||||
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
|
||||
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
|
||||
|
||||
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
|
||||
{
|
||||
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
}
|
||||
)
|
||||
276
src/diffusers/hooks/_helpers.py
Normal file
276
src/diffusers/hooks/_helpers.py
Normal file
@@ -0,0 +1,276 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Type
|
||||
|
||||
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
|
||||
from ..models.transformers.transformer_cogview4 import (
|
||||
CogView4AttnProcessor,
|
||||
CogView4PAGAttnProcessor,
|
||||
CogView4TransformerBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
||||
from ..models.transformers.transformer_hunyuan_video import (
|
||||
HunyuanVideoSingleTransformerBlock,
|
||||
HunyuanVideoTokenReplaceSingleTransformerBlock,
|
||||
HunyuanVideoTokenReplaceTransformerBlock,
|
||||
HunyuanVideoTransformerBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
|
||||
from ..models.transformers.transformer_mochi import MochiTransformerBlock
|
||||
from ..models.transformers.transformer_wan import WanTransformerBlock
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionProcessorMetadata:
|
||||
skip_processor_output_fn: Callable[[Any], Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GuidanceMetadata:
|
||||
perturbed_attention_guidance_processor_cls: Type = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerBlockMetadata:
|
||||
skip_block_output_fn: Callable[[Any], Any]
|
||||
return_hidden_states_index: int = None
|
||||
return_encoder_hidden_states_index: int = None
|
||||
|
||||
|
||||
class AttentionProcessorRegistry:
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
|
||||
cls._registry[model_class] = metadata
|
||||
|
||||
@classmethod
|
||||
def get(cls, model_class: Type) -> AttentionProcessorMetadata:
|
||||
if model_class not in cls._registry:
|
||||
raise ValueError(f"Model class {model_class} not registered.")
|
||||
return cls._registry[model_class]
|
||||
|
||||
|
||||
class GuidanceMetadataRegistry:
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class: Type, metadata: GuidanceMetadata):
|
||||
cls._registry[model_class] = metadata
|
||||
|
||||
@classmethod
|
||||
def get(cls, model_class: Type) -> GuidanceMetadata:
|
||||
if model_class not in cls._registry:
|
||||
raise ValueError(f"Model class {model_class} not registered.")
|
||||
return cls._registry[model_class]
|
||||
|
||||
|
||||
class TransformerBlockRegistry:
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
|
||||
cls._registry[model_class] = metadata
|
||||
|
||||
@classmethod
|
||||
def get(cls, model_class: Type) -> TransformerBlockMetadata:
|
||||
if model_class not in cls._registry:
|
||||
raise ValueError(f"Model class {model_class} not registered.")
|
||||
return cls._registry[model_class]
|
||||
|
||||
|
||||
def _register_attention_processors_metadata():
|
||||
# CogView4
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=CogView4AttnProcessor,
|
||||
metadata=AttentionProcessorMetadata(
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _register_guidance_metadata():
|
||||
# CogView4
|
||||
GuidanceMetadataRegistry.register(
|
||||
model_class=CogView4AttnProcessor,
|
||||
metadata=GuidanceMetadata(
|
||||
perturbed_attention_guidance_processor_cls=CogView4PAGAttnProcessor,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _register_transformer_blocks_metadata():
|
||||
# CogVideoX
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=CogVideoXBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# CogView4
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=CogView4TransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# Flux
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=FluxTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock,
|
||||
return_hidden_states_index=1,
|
||||
return_encoder_hidden_states_index=0,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=FluxSingleTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock,
|
||||
return_hidden_states_index=1,
|
||||
return_encoder_hidden_states_index=0,
|
||||
),
|
||||
)
|
||||
|
||||
# HunyuanVideo
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanVideoTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanVideoSingleTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanVideoTokenReplaceTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# LTXVideo
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=LTXVideoTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
),
|
||||
)
|
||||
|
||||
# Mochi
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=MochiTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# Wan
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=WanTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock,
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
|
||||
hidden_states = kwargs.get("hidden_states", None)
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
||||
if hidden_states is None and len(args) > 0:
|
||||
hidden_states = args[0]
|
||||
if encoder_hidden_states is None and len(args) > 1:
|
||||
encoder_hidden_states = args[1]
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
|
||||
|
||||
|
||||
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
|
||||
hidden_states = kwargs.get("hidden_states", None)
|
||||
if hidden_states is None and len(args) > 0:
|
||||
hidden_states = args[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
|
||||
hidden_states = kwargs.get("hidden_states", None)
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
||||
if hidden_states is None and len(args) > 0:
|
||||
hidden_states = args[0]
|
||||
if encoder_hidden_states is None and len(args) > 1:
|
||||
encoder_hidden_states = args[1]
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs):
|
||||
hidden_states = kwargs.get("hidden_states", None)
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
||||
if hidden_states is None and len(args) > 0:
|
||||
hidden_states = args[0]
|
||||
if encoder_hidden_states is None and len(args) > 1:
|
||||
encoder_hidden_states = args[1]
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
|
||||
_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
|
||||
_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
|
||||
_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
|
||||
_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
|
||||
# fmt: on
|
||||
|
||||
|
||||
_register_attention_processors_metadata()
|
||||
_register_guidance_metadata()
|
||||
_register_transformer_blocks_metadata()
|
||||
223
src/diffusers/hooks/first_block_cache.py
Normal file
223
src/diffusers/hooks/first_block_cache.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
from ._helpers import TransformerBlockRegistry
|
||||
from .hooks import BaseMarkedState, HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
|
||||
_FBC_BLOCK_HOOK = "fbc_block_hook"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FirstBlockCacheConfig:
|
||||
r"""
|
||||
Configuration for [First Block
|
||||
Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
|
||||
|
||||
Args:
|
||||
threshold (`float`, defaults to `0.05`):
|
||||
The threshold to determine whether or not a forward pass through all layers of the model is required. A
|
||||
higher threshold usually results in lower number of forward passes and faster inference, but might lead to
|
||||
poorer generation quality. A lower threshold may not result in significant generation speedup. The
|
||||
threshold is compared against the absmean difference of the residuals between the current and cached
|
||||
outputs from the first transformer block. If the difference is below the threshold, the forward pass is
|
||||
skipped.
|
||||
"""
|
||||
|
||||
threshold: float = 0.05
|
||||
|
||||
|
||||
class FBCSharedBlockState(BaseMarkedState):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
||||
self.head_block_residual: torch.Tensor = None
|
||||
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
||||
self.should_compute: bool = True
|
||||
|
||||
def reset(self):
|
||||
self.tail_block_residuals = None
|
||||
self.should_compute = True
|
||||
|
||||
|
||||
class FBCHeadBlockHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(self, shared_state: FBCSharedBlockState, threshold: float):
|
||||
self.shared_state = shared_state
|
||||
self.threshold = threshold
|
||||
self._metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
|
||||
original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
|
||||
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
is_output_tuple = isinstance(output, tuple)
|
||||
|
||||
if is_output_tuple:
|
||||
hs_residual = output[self._metadata.return_hidden_states_index] - original_hs
|
||||
else:
|
||||
hs_residual = output - original_hs
|
||||
|
||||
hs, ehs = None, None
|
||||
should_compute = self._should_compute_remaining_blocks(hs_residual)
|
||||
self.shared_state.should_compute = should_compute
|
||||
|
||||
if not should_compute:
|
||||
# Apply caching
|
||||
if is_output_tuple:
|
||||
hs = self.shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
|
||||
else:
|
||||
hs = self.shared_state.tail_block_residuals[0] + output
|
||||
|
||||
if self._metadata.return_encoder_hidden_states_index is not None:
|
||||
assert is_output_tuple
|
||||
ehs = (
|
||||
self.shared_state.tail_block_residuals[1]
|
||||
+ output[self._metadata.return_encoder_hidden_states_index]
|
||||
)
|
||||
|
||||
if is_output_tuple:
|
||||
return_output = [None] * len(output)
|
||||
return_output[self._metadata.return_hidden_states_index] = hs
|
||||
return_output[self._metadata.return_encoder_hidden_states_index] = ehs
|
||||
return_output = tuple(return_output)
|
||||
else:
|
||||
return_output = hs
|
||||
output = return_output
|
||||
else:
|
||||
if is_output_tuple:
|
||||
head_block_output = [None] * len(output)
|
||||
head_block_output[0] = output[self._metadata.return_hidden_states_index]
|
||||
head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
|
||||
else:
|
||||
head_block_output = output
|
||||
self.shared_state.head_block_output = head_block_output
|
||||
self.shared_state.head_block_residual = hs_residual
|
||||
|
||||
return output
|
||||
|
||||
def reset_state(self, module):
|
||||
self.shared_state.reset()
|
||||
return module
|
||||
|
||||
@torch.compiler.disable
|
||||
def _should_compute_remaining_blocks(self, hs_residual: torch.Tensor) -> bool:
|
||||
if self.shared_state.head_block_residual is None:
|
||||
return True
|
||||
prev_hs_residual = self.shared_state.head_block_residual
|
||||
hs_absmean = (hs_residual - prev_hs_residual).abs().mean()
|
||||
prev_hs_mean = prev_hs_residual.abs().mean()
|
||||
diff = (hs_absmean / prev_hs_mean).item()
|
||||
return diff > self.threshold
|
||||
|
||||
|
||||
class FBCBlockHook(ModelHook):
|
||||
def __init__(self, shared_state: FBCSharedBlockState, is_tail: bool = False):
|
||||
super().__init__()
|
||||
self.shared_state = shared_state
|
||||
self.is_tail = is_tail
|
||||
self._metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
outputs_if_skipped = self._metadata.skip_block_output_fn(module, *args, **kwargs)
|
||||
if not isinstance(outputs_if_skipped, tuple):
|
||||
outputs_if_skipped = (outputs_if_skipped,)
|
||||
original_hs = outputs_if_skipped[self._metadata.return_hidden_states_index]
|
||||
original_ehs = None
|
||||
if self._metadata.return_encoder_hidden_states_index is not None:
|
||||
original_ehs = outputs_if_skipped[self._metadata.return_encoder_hidden_states_index]
|
||||
|
||||
if self.shared_state.should_compute:
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
if self.is_tail:
|
||||
hs_residual, ehs_residual = None, None
|
||||
if isinstance(output, tuple):
|
||||
hs_residual = (
|
||||
output[self._metadata.return_hidden_states_index] - self.shared_state.head_block_output[0]
|
||||
)
|
||||
ehs_residual = (
|
||||
output[self._metadata.return_encoder_hidden_states_index]
|
||||
- self.shared_state.head_block_output[1]
|
||||
)
|
||||
else:
|
||||
hs_residual = output - self.shared_state.head_block_output
|
||||
self.shared_state.tail_block_residuals = (hs_residual, ehs_residual)
|
||||
return output
|
||||
|
||||
output_count = len(outputs_if_skipped)
|
||||
if output_count == 1:
|
||||
return_output = original_hs
|
||||
else:
|
||||
return_output = [None] * output_count
|
||||
return_output[self._metadata.return_hidden_states_index] = original_hs
|
||||
return_output[self._metadata.return_encoder_hidden_states_index] = original_ehs
|
||||
return return_output
|
||||
|
||||
|
||||
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
|
||||
shared_state = FBCSharedBlockState()
|
||||
remaining_blocks = []
|
||||
|
||||
for name, submodule in module.named_children():
|
||||
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
|
||||
continue
|
||||
for index, block in enumerate(submodule):
|
||||
remaining_blocks.append((f"{name}.{index}", block))
|
||||
|
||||
head_block_name, head_block = remaining_blocks.pop(0)
|
||||
tail_block_name, tail_block = remaining_blocks.pop(-1)
|
||||
|
||||
logger.debug(f"Apply FBCHeadBlockHook to '{head_block_name}'")
|
||||
apply_fbc_head_block_hook(head_block, shared_state, config.threshold)
|
||||
|
||||
for name, block in remaining_blocks:
|
||||
logger.debug(f"Apply FBCBlockHook to '{name}'")
|
||||
apply_fbc_block_hook(block, shared_state)
|
||||
|
||||
logger.debug(f"Apply FBCBlockHook to tail block '{tail_block_name}'")
|
||||
apply_fbc_block_hook(tail_block, shared_state, is_tail=True)
|
||||
|
||||
|
||||
def apply_fbc_head_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, threshold: float) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
hook = FBCHeadBlockHook(state, threshold)
|
||||
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
|
||||
|
||||
|
||||
def apply_fbc_block_hook(block: torch.nn.Module, state: FBCSharedBlockState, is_tail: bool = False) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
hook = FBCBlockHook(state, is_tail)
|
||||
registry.register_hook(hook, _FBC_BLOCK_HOOK)
|
||||
@@ -18,11 +18,76 @@ from typing import Any, Dict, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from ..utils.logging import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class BaseState:
|
||||
def reset(self, *args, **kwargs) -> None:
|
||||
raise NotImplementedError(
|
||||
"BaseState::reset is not implemented. Please implement this method in the derived class."
|
||||
)
|
||||
|
||||
|
||||
class BaseMarkedState(BaseState):
|
||||
def __init__(self, init_args=None, init_kwargs=None):
|
||||
super().__init__()
|
||||
|
||||
self._init_args = init_args if init_args is not None else ()
|
||||
self._init_kwargs = init_kwargs if init_kwargs is not None else {}
|
||||
self._mark_name = None
|
||||
self._state_cache = {}
|
||||
|
||||
def get_current_state(self) -> "BaseMarkedState":
|
||||
if self._mark_name is None:
|
||||
# If no mark name is set, simply return a dummy object since we're not going to be using it
|
||||
return self
|
||||
if self._mark_name not in self._state_cache.keys():
|
||||
self._state_cache[self._mark_name] = self.__class__(*self._init_args, **self._init_kwargs)
|
||||
return self._state_cache[self._mark_name]
|
||||
|
||||
def mark_state(self, name: str) -> None:
|
||||
self._mark_name = name
|
||||
|
||||
def reset(self, *args, **kwargs) -> None:
|
||||
for name, state in list(self._state_cache.items()):
|
||||
state.reset(*args, **kwargs)
|
||||
self._state_cache.pop(name)
|
||||
self._mark_name = None
|
||||
|
||||
def __getattribute__(self, name):
|
||||
if name in (
|
||||
"get_current_state",
|
||||
"mark_state",
|
||||
"reset",
|
||||
"_init_args",
|
||||
"_init_kwargs",
|
||||
"_mark_name",
|
||||
"_state_cache",
|
||||
) or _is_dunder_method(name):
|
||||
return object.__getattribute__(self, name)
|
||||
else:
|
||||
current_state = BaseMarkedState.get_current_state(self)
|
||||
return object.__getattribute__(current_state, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name in (
|
||||
"get_current_state",
|
||||
"mark_state",
|
||||
"reset",
|
||||
"_init_args",
|
||||
"_init_kwargs",
|
||||
"_mark_name",
|
||||
"_state_cache",
|
||||
) or _is_dunder_method(name):
|
||||
object.__setattr__(self, name, value)
|
||||
else:
|
||||
current_state = BaseMarkedState.get_current_state(self)
|
||||
object.__setattr__(current_state, name, value)
|
||||
|
||||
|
||||
class ModelHook:
|
||||
r"""
|
||||
A hook that contains callbacks to be executed just before and after the forward method of a model.
|
||||
@@ -99,6 +164,14 @@ class ModelHook:
|
||||
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
|
||||
return module
|
||||
|
||||
def _mark_state(self, module: torch.nn.Module, name: str) -> None:
|
||||
# Iterate over all attributes of the hook to see if any of them have the type `BaseMarkedState`. If so, call `mark_state` on them.
|
||||
for attr_name in dir(self):
|
||||
attr = getattr(self, attr_name)
|
||||
if isinstance(attr, BaseMarkedState):
|
||||
attr.mark_state(name)
|
||||
return module
|
||||
|
||||
|
||||
class HookFunctionReference:
|
||||
def __init__(self) -> None:
|
||||
@@ -211,9 +284,10 @@ class HookRegistry:
|
||||
hook.reset_state(self._module_ref)
|
||||
|
||||
if recurse:
|
||||
for module_name, module in self._module_ref.named_modules():
|
||||
for module_name, module in unwrap_module(self._module_ref).named_modules():
|
||||
if module_name == "":
|
||||
continue
|
||||
module = unwrap_module(module)
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook.reset_stateful_hooks(recurse=False)
|
||||
|
||||
@@ -223,6 +297,19 @@ class HookRegistry:
|
||||
module._diffusers_hook = cls(module)
|
||||
return module._diffusers_hook
|
||||
|
||||
def _mark_state(self, name: str) -> None:
|
||||
for hook_name in reversed(self._hook_order):
|
||||
hook = self.hooks[hook_name]
|
||||
if hook._is_stateful:
|
||||
hook._mark_state(self._module_ref, name)
|
||||
|
||||
for module_name, module in unwrap_module(self._module_ref).named_modules():
|
||||
if module_name == "":
|
||||
continue
|
||||
module = unwrap_module(module)
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook._mark_state(name)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
registry_repr = ""
|
||||
for i, hook_name in enumerate(self._hook_order):
|
||||
@@ -234,3 +321,7 @@ class HookRegistry:
|
||||
if i < len(self._hook_order) - 1:
|
||||
registry_repr += "\n"
|
||||
return f"HookRegistry(\n{registry_repr}\n)"
|
||||
|
||||
|
||||
def _is_dunder_method(name: str) -> bool:
|
||||
return name.startswith("__") and name.endswith("__") and name in dir(object)
|
||||
|
||||
182
src/diffusers/hooks/layer_skip.py
Normal file
182
src/diffusers/hooks/layer_skip.py
Normal file
@@ -0,0 +1,182 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES
|
||||
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_LAYER_SKIP_HOOK = "layer_skip_hook"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LayerSkipConfig:
|
||||
r"""
|
||||
Configuration for skipping internal transformer blocks when executing a transformer model.
|
||||
|
||||
Args:
|
||||
indices (`List[int]`):
|
||||
The indices of the layer to skip. This is typically the first layer in the transformer block.
|
||||
fqn (`str`, defaults to `"auto"`):
|
||||
The fully qualified name identifying the stack of transformer blocks. Typically, this is
|
||||
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
|
||||
"""
|
||||
|
||||
indices: List[int]
|
||||
fqn: str = "auto"
|
||||
skip_attention: bool = True
|
||||
skip_attention_scores: bool = False
|
||||
skip_ff: bool = True
|
||||
|
||||
|
||||
class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if func is torch.nn.functional.scaled_dot_product_attention:
|
||||
value = kwargs.get("value", None)
|
||||
if value is None:
|
||||
value = args[2]
|
||||
return value
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
class AttentionProcessorSkipHook(ModelHook):
|
||||
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False):
|
||||
self.skip_processor_output_fn = skip_processor_output_fn
|
||||
self.skip_attention_scores = skip_attention_scores
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if self.skip_attention_scores:
|
||||
with AttentionScoreSkipFunctionMode():
|
||||
return self.fn_ref.original_forward(*args, **kwargs)
|
||||
else:
|
||||
return self.skip_processor_output_fn(module, *args, **kwargs)
|
||||
|
||||
|
||||
class FeedForwardSkipHook(ModelHook):
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
output = kwargs.get("hidden_states", None)
|
||||
if output is None:
|
||||
output = kwargs.get("x", None)
|
||||
if output is None and len(args) > 0:
|
||||
output = args[0]
|
||||
return output
|
||||
|
||||
|
||||
class TransformerBlockSkipHook(ModelHook):
|
||||
def initialize_hook(self, module):
|
||||
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
return self._metadata.skip_block_output_fn(module, *args, **kwargs)
|
||||
|
||||
|
||||
def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
|
||||
r"""
|
||||
Apply layer skipping to internal layers of a transformer.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The transformer model to which the layer skip hook should be applied.
|
||||
config (`LayerSkipConfig`):
|
||||
The configuration for the layer skip hook.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig
|
||||
|
||||
>>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
>>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks")
|
||||
>>> apply_layer_skip_hook(transformer, config)
|
||||
```
|
||||
"""
|
||||
_apply_layer_skip_hook(module, config)
|
||||
|
||||
|
||||
def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
|
||||
name = name or _LAYER_SKIP_HOOK
|
||||
|
||||
if config.skip_attention and config.skip_attention_scores:
|
||||
raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
|
||||
|
||||
if config.fqn == "auto":
|
||||
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
|
||||
if hasattr(module, identifier):
|
||||
config.fqn = identifier
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
|
||||
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
|
||||
)
|
||||
|
||||
transformer_blocks = getattr(module, config.fqn, None)
|
||||
if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
|
||||
raise ValueError(
|
||||
f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
|
||||
f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
|
||||
)
|
||||
if len(config.indices) == 0:
|
||||
raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")
|
||||
|
||||
blocks_found = False
|
||||
for i, block in enumerate(transformer_blocks):
|
||||
if i not in config.indices:
|
||||
continue
|
||||
blocks_found = True
|
||||
if config.skip_attention and config.skip_ff:
|
||||
logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
hook = TransformerBlockSkipHook()
|
||||
registry.register_hook(hook, name)
|
||||
elif config.skip_attention or config.skip_attention_scores:
|
||||
for submodule_name, submodule in block.named_modules():
|
||||
if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
|
||||
logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
|
||||
output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
|
||||
registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
||||
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores)
|
||||
registry.register_hook(hook, name)
|
||||
elif config.skip_ff:
|
||||
for submodule_name, submodule in block.named_modules():
|
||||
if isinstance(submodule, _FEEDFORWARD_CLASSES):
|
||||
logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
|
||||
registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
||||
hook = FeedForwardSkipHook()
|
||||
registry.register_hook(hook, name)
|
||||
else:
|
||||
raise ValueError(
|
||||
"At least one of `skip_attention`, `skip_attention_scores`, or `skip_ff` must be set to True."
|
||||
)
|
||||
|
||||
if not blocks_found:
|
||||
raise ValueError(
|
||||
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
|
||||
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
|
||||
)
|
||||
@@ -44,6 +44,7 @@ from ..utils import (
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from ..utils.hub_utils import _get_model_file
|
||||
|
||||
|
||||
@@ -443,7 +444,7 @@ def fetch_original_config(original_config_file, local_files_only=False):
|
||||
"Please provide a valid local file path."
|
||||
)
|
||||
|
||||
original_config_file = BytesIO(requests.get(original_config_file).content)
|
||||
original_config_file = BytesIO(requests.get(original_config_file, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
|
||||
@@ -2406,7 +2407,6 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
"per_channel_statistics.channel": remove_keys_,
|
||||
"per_channel_statistics.mean-of-means": remove_keys_,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_,
|
||||
"timestep_scale_multiplier": remove_keys_,
|
||||
}
|
||||
|
||||
if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from ..utils.logging import get_logger
|
||||
|
||||
|
||||
@@ -25,6 +27,7 @@ class CacheMixin:
|
||||
Supported caching techniques:
|
||||
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
|
||||
- [FasterCache](https://huggingface.co/papers/2410.19355)
|
||||
- [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
|
||||
"""
|
||||
|
||||
_cache_config = None
|
||||
@@ -62,8 +65,10 @@ class CacheMixin:
|
||||
|
||||
from ..hooks import (
|
||||
FasterCacheConfig,
|
||||
FirstBlockCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
)
|
||||
|
||||
@@ -72,31 +77,36 @@ class CacheMixin:
|
||||
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
|
||||
)
|
||||
|
||||
if isinstance(config, PyramidAttentionBroadcastConfig):
|
||||
apply_pyramid_attention_broadcast(self, config)
|
||||
elif isinstance(config, FasterCacheConfig):
|
||||
if isinstance(config, FasterCacheConfig):
|
||||
apply_faster_cache(self, config)
|
||||
elif isinstance(config, FirstBlockCacheConfig):
|
||||
apply_first_block_cache(self, config)
|
||||
elif isinstance(config, PyramidAttentionBroadcastConfig):
|
||||
apply_pyramid_attention_broadcast(self, config)
|
||||
else:
|
||||
raise ValueError(f"Cache config {type(config)} is not supported.")
|
||||
|
||||
self._cache_config = config
|
||||
|
||||
def disable_cache(self) -> None:
|
||||
from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
|
||||
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
|
||||
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
||||
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
|
||||
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
|
||||
|
||||
if self._cache_config is None:
|
||||
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
|
||||
return
|
||||
|
||||
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self)
|
||||
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, FasterCacheConfig):
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self)
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self)
|
||||
if isinstance(self._cache_config, FasterCacheConfig):
|
||||
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
|
||||
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, FirstBlockCacheConfig):
|
||||
registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
|
||||
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
|
||||
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
|
||||
else:
|
||||
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
|
||||
|
||||
@@ -106,3 +116,21 @@ class CacheMixin:
|
||||
from ..hooks import HookRegistry
|
||||
|
||||
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
|
||||
|
||||
@contextmanager
|
||||
def _cache_context(self):
|
||||
r"""Context manager that provides additional methods for cache management."""
|
||||
cache_context = _CacheContextManager(self)
|
||||
yield cache_context
|
||||
|
||||
|
||||
class _CacheContextManager:
|
||||
def __init__(self, model: CacheMixin):
|
||||
self.model = model
|
||||
|
||||
def mark_state(self, name: str) -> None:
|
||||
from ..hooks import HookRegistry
|
||||
|
||||
if self.model.is_cache_enabled:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(self.model)
|
||||
registry._mark_state(name)
|
||||
|
||||
@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
||||
)
|
||||
block_samples = block_samples + (hidden_states,)
|
||||
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
single_block_samples = ()
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
|
||||
single_block_samples = single_block_samples + (hidden_states,)
|
||||
|
||||
# controlnet block
|
||||
controlnet_block_samples = ()
|
||||
|
||||
@@ -460,3 +460,84 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
|
||||
### ===== Custom attention processors for guidance methods ===== ###
|
||||
|
||||
|
||||
class CogView4PAGAttnProcessor:
|
||||
"""
|
||||
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
skip_context_attention: bool = False,
|
||||
) -> torch.Tensor:
|
||||
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
|
||||
batch_size, image_seq_length, embed_dim = hidden_states.shape
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# 1. QKV projections
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
# 2. QK normalization
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# 3. Rotational positional embeddings applied to latent stream
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
query[:, :, text_seq_length:, :] = apply_rotary_emb(
|
||||
query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
|
||||
)
|
||||
key[:, :, text_seq_length:, :] = apply_rotary_emb(
|
||||
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
|
||||
)
|
||||
|
||||
# 4. Attention
|
||||
if skip_context_attention:
|
||||
hidden_states = value
|
||||
else:
|
||||
# PAG uses a custom attention mask for perturbed attention path:
|
||||
# - Create attention mask with `float("-inf")` for image tokens and `0.0` for text tokens
|
||||
# - Set diagonal to `0.0` for attention between image tokens
|
||||
seq_length = text_seq_length + image_seq_length
|
||||
perturbed_attention_mask = hidden_states.new_full((seq_length, seq_length), float("-inf"))
|
||||
perturbed_attention_mask[:text_seq_length, :text_seq_length] = 0.0
|
||||
perturbed_attention_mask.fill_diagonal_(0.0)
|
||||
perturbed_attention_mask = perturbed_attention_mask.unsqueeze(0).unsqueeze(0)
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=perturbed_attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
|
||||
# 5. Output projection
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
@@ -79,10 +79,14 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> torch.Tensor:
|
||||
text_seq_len = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
residual = hidden_states
|
||||
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
||||
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
||||
@@ -100,7 +104,8 @@ class FluxSingleTransformerBlock(nn.Module):
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = hidden_states.clip(-65504, 65504)
|
||||
|
||||
return hidden_states
|
||||
encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
@@ -508,20 +513,21 @@ class FluxTransformer2DModel(
|
||||
)
|
||||
else:
|
||||
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
)
|
||||
|
||||
else:
|
||||
hidden_states = block(
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
@@ -531,12 +537,7 @@ class FluxTransformer2DModel(
|
||||
if controlnet_single_block_samples is not None:
|
||||
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
|
||||
interval_control = int(np.ceil(interval_control))
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
||||
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
+ controlnet_single_block_samples[index_block // interval_control]
|
||||
)
|
||||
|
||||
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
||||
hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
@@ -21,6 +21,7 @@ import torch
|
||||
from transformers import AutoTokenizer, GlmModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...guiders import ClassifierFreeGuidance, GuidanceMixin, _raise_guidance_deprecation_warning
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import CogView4LoraLoaderMixin
|
||||
from ...models import AutoencoderKL, CogView4Transformer2DModel
|
||||
@@ -426,6 +427,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 1024,
|
||||
guidance: Optional[GuidanceMixin] = None,
|
||||
) -> Union[CogView4PipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -514,6 +516,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
_raise_guidance_deprecation_warning(guidance_scale=guidance_scale)
|
||||
if guidance is None:
|
||||
guidance = ClassifierFreeGuidance(guidance_scale=guidance_scale)
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
@@ -608,46 +614,47 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
transformer_dtype = self.transformer.dtype
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
conds = [prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left]
|
||||
prompt_embeds, negative_prompt_embeds, original_size, target_size, crops_coords_top_left = [[v] for v in conds]
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
|
||||
for i, t in enumerate(timesteps):
|
||||
self._current_timestep = t
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
guidance.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
|
||||
guidance.prepare_models(self.transformer)
|
||||
latents, prompt_embeds, original_size, target_size, crops_coords_top_left = guidance.prepare_inputs(
|
||||
latents,
|
||||
(prompt_embeds[0], negative_prompt_embeds[0]),
|
||||
original_size[0],
|
||||
target_size[0],
|
||||
crops_coords_top_left[0],
|
||||
)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
noise_pred_cond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
original_size=original_size,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
for batch_index, (latent, condition, original_size_c, target_size_c, crop_coord_c) in enumerate(
|
||||
zip(latents, prompt_embeds, original_size, target_size, crops_coords_top_left)
|
||||
):
|
||||
cc.mark_state(f"batch_{batch_index}")
|
||||
latent = latent.to(transformer_dtype)
|
||||
timestep = t.expand(latent.shape[0])
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent,
|
||||
encoder_hidden_states=condition,
|
||||
timestep=timestep,
|
||||
original_size=original_size,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
original_size=original_size_c,
|
||||
target_size=target_size_c,
|
||||
crop_coords=crop_coord_c,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
guidance.prepare_outputs(noise_pred)
|
||||
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
else:
|
||||
noise_pred = noise_pred_cond
|
||||
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
outputs = guidance.outputs
|
||||
noise_pred = guidance(**outputs)
|
||||
latents = self.scheduler.step(noise_pred, t, latents[0], return_dict=False)[0]
|
||||
guidance.cleanup_models(self.transformer)
|
||||
|
||||
# call the callback, if provided
|
||||
if callback_on_step_end is not None:
|
||||
@@ -656,8 +663,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
prompt_embeds = [callback_outputs.pop("prompt_embeds", prompt_embeds[0])]
|
||||
negative_prompt_embeds = [
|
||||
callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds[0])
|
||||
]
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
@@ -906,7 +906,7 @@ class FluxPipeline(
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
@@ -917,6 +917,7 @@ class FluxPipeline(
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
cc.mark_state("cond")
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
@@ -932,6 +933,8 @@ class FluxPipeline(
|
||||
if do_true_cfg:
|
||||
if negative_image_embeds is not None:
|
||||
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
|
||||
|
||||
cc.mark_state("uncond")
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / 1000,
|
||||
|
||||
@@ -224,11 +224,13 @@ class FluxFillPipeline(
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
|
||||
self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
|
||||
self.image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels
|
||||
)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor * 2,
|
||||
vae_latent_channels=latent_channels,
|
||||
vae_latent_channels=self.latent_channels,
|
||||
do_normalize=False,
|
||||
do_binarize=True,
|
||||
do_convert_grayscale=True,
|
||||
@@ -493,10 +495,38 @@ class FluxFillPipeline(
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds, text_ids
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
|
||||
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
||||
|
||||
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
return image_latents
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
||||
|
||||
t_start = int(max(num_inference_steps - init_timestep, 0))
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
prompt_2,
|
||||
strength,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
@@ -507,6 +537,9 @@ class FluxFillPipeline(
|
||||
mask_image=None,
|
||||
masked_image_latents=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
@@ -624,9 +657,11 @@ class FluxFillPipeline(
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux_img2img.FluxImg2ImgPipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
image,
|
||||
timestep,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
@@ -636,28 +671,41 @@ class FluxFillPipeline(
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
||||
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype), latent_image_ids
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if image.shape[1] != self.latent_channels:
|
||||
image_latents = self._encode_vae_image(image=image, generator=generator)
|
||||
else:
|
||||
image_latents = image
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
image_latents = torch.cat([image_latents], dim=0)
|
||||
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
||||
return latents, latent_image_ids
|
||||
|
||||
@property
|
||||
@@ -687,6 +735,7 @@ class FluxFillPipeline(
|
||||
masked_image_latents: Optional[torch.FloatTensor] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
strength: float = 1.0,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 30.0,
|
||||
@@ -731,6 +780,12 @@ class FluxFillPipeline(
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
strength (`float`, *optional*, defaults to 1.0):
|
||||
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
||||
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
||||
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
||||
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
||||
essentially ignores `image`.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
@@ -794,6 +849,7 @@ class FluxFillPipeline(
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
strength,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=prompt_embeds,
|
||||
@@ -809,6 +865,9 @@ class FluxFillPipeline(
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
init_image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
init_image = init_image.to(dtype=torch.float32)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
@@ -838,47 +897,9 @@ class FluxFillPipeline(
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
latents, latent_image_ids = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 5. Prepare mask and masked image latents
|
||||
if masked_image_latents is not None:
|
||||
masked_image_latents = masked_image_latents.to(latents.device)
|
||||
else:
|
||||
image = self.image_processor.preprocess(image, height=height, width=width)
|
||||
mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
|
||||
masked_image = image * (1 - mask_image)
|
||||
masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
|
||||
|
||||
height, width = image.shape[-2:]
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask_image,
|
||||
masked_image,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_images_per_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
)
|
||||
masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
|
||||
|
||||
# 6. Prepare timesteps
|
||||
# 4. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
image_seq_len = latents.shape[1]
|
||||
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
@@ -893,6 +914,54 @@ class FluxFillPipeline(
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
|
||||
if num_inference_steps < 1:
|
||||
raise ValueError(
|
||||
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
|
||||
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
||||
)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
latents, latent_image_ids = self.prepare_latents(
|
||||
init_image,
|
||||
latent_timestep,
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Prepare mask and masked image latents
|
||||
if masked_image_latents is not None:
|
||||
masked_image_latents = masked_image_latents.to(latents.device)
|
||||
else:
|
||||
mask_image = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
|
||||
masked_image = init_image * (1 - mask_image)
|
||||
masked_image = masked_image.to(device=device, dtype=prompt_embeds.dtype)
|
||||
|
||||
height, width = init_image.shape[-2:]
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask_image,
|
||||
masked_image,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
num_images_per_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
)
|
||||
masked_image_latents = torch.cat((masked_image_latents, mask), dim=-1)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
|
||||
@@ -683,7 +683,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
@@ -693,6 +693,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
cc.mark_state("cond")
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
@@ -705,6 +706,7 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
|
||||
)[0]
|
||||
|
||||
if do_true_cfg:
|
||||
cc.mark_state("uncond")
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
|
||||
@@ -19,7 +19,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin
|
||||
from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin
|
||||
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
|
||||
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
@@ -121,7 +121,7 @@ def retrieve_timesteps(
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin):
|
||||
class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionLoraLoaderMixin, IPAdapterMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Kolors.
|
||||
|
||||
@@ -129,8 +129,8 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
The pipeline also inherits the following loading methods:
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
||||
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
||||
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
||||
|
||||
Args:
|
||||
|
||||
@@ -706,7 +706,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
)
|
||||
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
@@ -719,6 +719,7 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
cc.mark_state("cond_uncond")
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
@@ -75,6 +75,7 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
>>> # Generate video
|
||||
>>> generator = torch.Generator("cuda").manual_seed(0)
|
||||
>>> # Text-only conditioning is also supported without the need to pass `conditions`
|
||||
>>> video = pipe(
|
||||
... conditions=[condition1, condition2],
|
||||
... prompt=prompt,
|
||||
@@ -223,7 +224,7 @@ def retrieve_latents(
|
||||
|
||||
class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for image-to-video generation.
|
||||
Pipeline for text/image/video-to-video generation.
|
||||
|
||||
Reference: https://github.com/Lightricks/LTX-Video
|
||||
|
||||
@@ -482,9 +483,6 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
if conditions is not None and (image is not None or video is not None):
|
||||
raise ValueError("If `conditions` is provided, `image` and `video` must not be provided.")
|
||||
|
||||
if conditions is None and (image is None and video is None):
|
||||
raise ValueError("If `conditions` is not provided, `image` or `video` must be provided.")
|
||||
|
||||
if conditions is None:
|
||||
if isinstance(image, list) and isinstance(frame_index, list) and len(image) != len(frame_index):
|
||||
raise ValueError(
|
||||
@@ -642,9 +640,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
conditions: List[torch.Tensor],
|
||||
condition_strength: List[float],
|
||||
condition_frame_index: List[int],
|
||||
conditions: Optional[List[torch.Tensor]] = None,
|
||||
condition_strength: Optional[List[float]] = None,
|
||||
condition_frame_index: Optional[List[int]] = None,
|
||||
batch_size: int = 1,
|
||||
num_channels_latents: int = 128,
|
||||
height: int = 512,
|
||||
@@ -654,7 +652,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
generator: Optional[torch.Generator] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
|
||||
num_latent_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
@@ -662,77 +660,80 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
condition_latent_frames_mask = torch.zeros((batch_size, num_latent_frames), device=device, dtype=torch.float32)
|
||||
if len(conditions) > 0:
|
||||
condition_latent_frames_mask = torch.zeros(
|
||||
(batch_size, num_latent_frames), device=device, dtype=torch.float32
|
||||
)
|
||||
|
||||
extra_conditioning_latents = []
|
||||
extra_conditioning_video_ids = []
|
||||
extra_conditioning_mask = []
|
||||
extra_conditioning_num_latents = 0
|
||||
for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
|
||||
condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
|
||||
condition_latents = self._normalize_latents(
|
||||
condition_latents, self.vae.latents_mean, self.vae.latents_std
|
||||
).to(device, dtype=dtype)
|
||||
extra_conditioning_latents = []
|
||||
extra_conditioning_video_ids = []
|
||||
extra_conditioning_mask = []
|
||||
extra_conditioning_num_latents = 0
|
||||
for data, strength, frame_index in zip(conditions, condition_strength, condition_frame_index):
|
||||
condition_latents = retrieve_latents(self.vae.encode(data), generator=generator)
|
||||
condition_latents = self._normalize_latents(
|
||||
condition_latents, self.vae.latents_mean, self.vae.latents_std
|
||||
).to(device, dtype=dtype)
|
||||
|
||||
num_data_frames = data.size(2)
|
||||
num_cond_frames = condition_latents.size(2)
|
||||
num_data_frames = data.size(2)
|
||||
num_cond_frames = condition_latents.size(2)
|
||||
|
||||
if frame_index == 0:
|
||||
latents[:, :, :num_cond_frames] = torch.lerp(
|
||||
latents[:, :, :num_cond_frames], condition_latents, strength
|
||||
)
|
||||
condition_latent_frames_mask[:, :num_cond_frames] = strength
|
||||
if frame_index == 0:
|
||||
latents[:, :, :num_cond_frames] = torch.lerp(
|
||||
latents[:, :, :num_cond_frames], condition_latents, strength
|
||||
)
|
||||
condition_latent_frames_mask[:, :num_cond_frames] = strength
|
||||
|
||||
else:
|
||||
if num_data_frames > 1:
|
||||
if num_cond_frames < num_prefix_latent_frames:
|
||||
raise ValueError(
|
||||
f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
|
||||
)
|
||||
else:
|
||||
if num_data_frames > 1:
|
||||
if num_cond_frames < num_prefix_latent_frames:
|
||||
raise ValueError(
|
||||
f"Number of latent frames must be at least {num_prefix_latent_frames} but got {num_data_frames}."
|
||||
)
|
||||
|
||||
if num_cond_frames > num_prefix_latent_frames:
|
||||
start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
|
||||
end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
|
||||
latents[:, :, start_frame:end_frame] = torch.lerp(
|
||||
latents[:, :, start_frame:end_frame],
|
||||
condition_latents[:, :, num_prefix_latent_frames:],
|
||||
strength,
|
||||
)
|
||||
condition_latent_frames_mask[:, start_frame:end_frame] = strength
|
||||
condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
|
||||
if num_cond_frames > num_prefix_latent_frames:
|
||||
start_frame = frame_index // self.vae_temporal_compression_ratio + num_prefix_latent_frames
|
||||
end_frame = start_frame + num_cond_frames - num_prefix_latent_frames
|
||||
latents[:, :, start_frame:end_frame] = torch.lerp(
|
||||
latents[:, :, start_frame:end_frame],
|
||||
condition_latents[:, :, num_prefix_latent_frames:],
|
||||
strength,
|
||||
)
|
||||
condition_latent_frames_mask[:, start_frame:end_frame] = strength
|
||||
condition_latents = condition_latents[:, :, :num_prefix_latent_frames]
|
||||
|
||||
noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
|
||||
condition_latents = torch.lerp(noise, condition_latents, strength)
|
||||
noise = randn_tensor(condition_latents.shape, generator=generator, device=device, dtype=dtype)
|
||||
condition_latents = torch.lerp(noise, condition_latents, strength)
|
||||
|
||||
condition_video_ids = self._prepare_video_ids(
|
||||
batch_size,
|
||||
condition_latents.size(2),
|
||||
latent_height,
|
||||
latent_width,
|
||||
patch_size=self.transformer_spatial_patch_size,
|
||||
patch_size_t=self.transformer_temporal_patch_size,
|
||||
device=device,
|
||||
)
|
||||
condition_video_ids = self._scale_video_ids(
|
||||
condition_video_ids,
|
||||
scale_factor=self.vae_spatial_compression_ratio,
|
||||
scale_factor_t=self.vae_temporal_compression_ratio,
|
||||
frame_index=frame_index,
|
||||
device=device,
|
||||
)
|
||||
condition_latents = self._pack_latents(
|
||||
condition_latents,
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
condition_conditioning_mask = torch.full(
|
||||
condition_latents.shape[:2], strength, device=device, dtype=dtype
|
||||
)
|
||||
condition_video_ids = self._prepare_video_ids(
|
||||
batch_size,
|
||||
condition_latents.size(2),
|
||||
latent_height,
|
||||
latent_width,
|
||||
patch_size=self.transformer_spatial_patch_size,
|
||||
patch_size_t=self.transformer_temporal_patch_size,
|
||||
device=device,
|
||||
)
|
||||
condition_video_ids = self._scale_video_ids(
|
||||
condition_video_ids,
|
||||
scale_factor=self.vae_spatial_compression_ratio,
|
||||
scale_factor_t=self.vae_temporal_compression_ratio,
|
||||
frame_index=frame_index,
|
||||
device=device,
|
||||
)
|
||||
condition_latents = self._pack_latents(
|
||||
condition_latents,
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
condition_conditioning_mask = torch.full(
|
||||
condition_latents.shape[:2], strength, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
extra_conditioning_latents.append(condition_latents)
|
||||
extra_conditioning_video_ids.append(condition_video_ids)
|
||||
extra_conditioning_mask.append(condition_conditioning_mask)
|
||||
extra_conditioning_num_latents += condition_latents.size(1)
|
||||
extra_conditioning_latents.append(condition_latents)
|
||||
extra_conditioning_video_ids.append(condition_video_ids)
|
||||
extra_conditioning_mask.append(condition_conditioning_mask)
|
||||
extra_conditioning_num_latents += condition_latents.size(1)
|
||||
|
||||
video_ids = self._prepare_video_ids(
|
||||
batch_size,
|
||||
@@ -743,7 +744,10 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
patch_size=self.transformer_spatial_patch_size,
|
||||
device=device,
|
||||
)
|
||||
conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
|
||||
if len(conditions) > 0:
|
||||
conditioning_mask = condition_latent_frames_mask.gather(1, video_ids[:, 0])
|
||||
else:
|
||||
conditioning_mask, extra_conditioning_num_latents = None, 0
|
||||
video_ids = self._scale_video_ids(
|
||||
video_ids,
|
||||
scale_factor=self.vae_spatial_compression_ratio,
|
||||
@@ -755,7 +759,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
latents, self.transformer_spatial_patch_size, self.transformer_temporal_patch_size
|
||||
)
|
||||
|
||||
if len(extra_conditioning_latents) > 0:
|
||||
if len(conditions) > 0 and len(extra_conditioning_latents) > 0:
|
||||
latents = torch.cat([*extra_conditioning_latents, latents], dim=1)
|
||||
video_ids = torch.cat([*extra_conditioning_video_ids, video_ids], dim=2)
|
||||
conditioning_mask = torch.cat([*extra_conditioning_mask, conditioning_mask], dim=1)
|
||||
@@ -955,7 +959,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
frame_index = [condition.frame_index for condition in conditions]
|
||||
image = [condition.image for condition in conditions]
|
||||
video = [condition.video for condition in conditions]
|
||||
else:
|
||||
elif image is not None or video is not None:
|
||||
if not isinstance(image, list):
|
||||
image = [image]
|
||||
num_conditions = 1
|
||||
@@ -999,32 +1003,34 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
vae_dtype = self.vae.dtype
|
||||
|
||||
conditioning_tensors = []
|
||||
for condition_image, condition_video, condition_frame_index, condition_strength in zip(
|
||||
image, video, frame_index, strength
|
||||
):
|
||||
if condition_image is not None:
|
||||
condition_tensor = (
|
||||
self.video_processor.preprocess(condition_image, height, width)
|
||||
.unsqueeze(2)
|
||||
.to(device, dtype=vae_dtype)
|
||||
)
|
||||
elif condition_video is not None:
|
||||
condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
|
||||
num_frames_input = condition_tensor.size(2)
|
||||
num_frames_output = self.trim_conditioning_sequence(
|
||||
condition_frame_index, num_frames_input, num_frames
|
||||
)
|
||||
condition_tensor = condition_tensor[:, :, :num_frames_output]
|
||||
condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
|
||||
else:
|
||||
raise ValueError("Either `image` or `video` must be provided in the `LTXVideoCondition`.")
|
||||
is_conditioning_image_or_video = image is not None or video is not None
|
||||
if is_conditioning_image_or_video:
|
||||
for condition_image, condition_video, condition_frame_index, condition_strength in zip(
|
||||
image, video, frame_index, strength
|
||||
):
|
||||
if condition_image is not None:
|
||||
condition_tensor = (
|
||||
self.video_processor.preprocess(condition_image, height, width)
|
||||
.unsqueeze(2)
|
||||
.to(device, dtype=vae_dtype)
|
||||
)
|
||||
elif condition_video is not None:
|
||||
condition_tensor = self.video_processor.preprocess_video(condition_video, height, width)
|
||||
num_frames_input = condition_tensor.size(2)
|
||||
num_frames_output = self.trim_conditioning_sequence(
|
||||
condition_frame_index, num_frames_input, num_frames
|
||||
)
|
||||
condition_tensor = condition_tensor[:, :, :num_frames_output]
|
||||
condition_tensor = condition_tensor.to(device, dtype=vae_dtype)
|
||||
else:
|
||||
raise ValueError("Either `image` or `video` must be provided for conditioning.")
|
||||
|
||||
if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
|
||||
raise ValueError(
|
||||
f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
|
||||
f"but got {condition_tensor.size(2)} frames."
|
||||
)
|
||||
conditioning_tensors.append(condition_tensor)
|
||||
if condition_tensor.size(2) % self.vae_temporal_compression_ratio != 1:
|
||||
raise ValueError(
|
||||
f"Number of frames in the video must be of the form (k * {self.vae_temporal_compression_ratio} + 1) "
|
||||
f"but got {condition_tensor.size(2)} frames."
|
||||
)
|
||||
conditioning_tensors.append(condition_tensor)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
@@ -1045,7 +1051,7 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
video_coords = video_coords.float()
|
||||
video_coords[:, 0] = video_coords[:, 0] * (1.0 / frame_rate)
|
||||
|
||||
init_latents = latents.clone()
|
||||
init_latents = latents.clone() if is_conditioning_image_or_video else None
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
video_coords = torch.cat([video_coords, video_coords], dim=0)
|
||||
@@ -1065,15 +1071,15 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
|
||||
if image_cond_noise_scale > 0:
|
||||
if image_cond_noise_scale > 0 and init_latents is not None:
|
||||
# Add timestep-dependent noise to the hard-conditioning latents
|
||||
# This helps with motion continuity, especially when conditioned on a single frame
|
||||
latents = self.add_noise_to_image_conditioning_latents(
|
||||
@@ -1086,17 +1092,20 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
)
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
conditioning_mask_model_input = (
|
||||
torch.cat([conditioning_mask, conditioning_mask])
|
||||
if self.do_classifier_free_guidance
|
||||
else conditioning_mask
|
||||
)
|
||||
if is_conditioning_image_or_video:
|
||||
conditioning_mask_model_input = (
|
||||
torch.cat([conditioning_mask, conditioning_mask])
|
||||
if self.do_classifier_free_guidance
|
||||
else conditioning_mask
|
||||
)
|
||||
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0]).unsqueeze(-1).float()
|
||||
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
|
||||
if is_conditioning_image_or_video:
|
||||
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
|
||||
|
||||
cc.mark_state("cond_uncond")
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
@@ -1115,8 +1124,11 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
denoised_latents = self.scheduler.step(
|
||||
-noise_pred, t, latents, per_token_timesteps=timestep, return_dict=False
|
||||
)[0]
|
||||
tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
|
||||
latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
|
||||
if is_conditioning_image_or_video:
|
||||
tokens_to_denoise_mask = (t / 1000 - 1e-6 < (1.0 - conditioning_mask)).unsqueeze(-1)
|
||||
latents = torch.where(tokens_to_denoise_mask, denoised_latents, latents)
|
||||
else:
|
||||
latents = denoised_latents
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
@@ -1134,7 +1146,9 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
latents = latents[:, extra_conditioning_num_latents:]
|
||||
if is_conditioning_image_or_video:
|
||||
latents = latents[:, extra_conditioning_num_latents:]
|
||||
|
||||
latents = self._unpack_latents(
|
||||
latents,
|
||||
latent_num_frames,
|
||||
|
||||
@@ -778,7 +778,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
)
|
||||
|
||||
# 7. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
@@ -792,6 +792,7 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
|
||||
|
||||
cc.mark_state("cond_uncond")
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
|
||||
@@ -52,6 +52,7 @@ from ...schedulers import (
|
||||
UnCLIPScheduler,
|
||||
)
|
||||
from ...utils import is_accelerate_available, logging
|
||||
from ...utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from ..latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
|
||||
from ..paint_by_example import PaintByExampleImageEncoder
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
@@ -1324,7 +1325,7 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
|
||||
|
||||
if config_url is not None:
|
||||
original_config_file = BytesIO(requests.get(config_url).content)
|
||||
original_config_file = BytesIO(requests.get(config_url, timeout=DIFFUSERS_REQUEST_TIMEOUT).content)
|
||||
else:
|
||||
with open(original_config_file, "r") as f:
|
||||
original_config_file = f.read()
|
||||
|
||||
@@ -519,7 +519,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar, self.transformer._cache_context() as cc:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
@@ -528,6 +528,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
cc.mark_state("cond")
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
@@ -537,6 +538,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
cc.mark_state("uncond")
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
|
||||
@@ -40,6 +40,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://hugging
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
||||
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
|
||||
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
|
||||
DIFFUSERS_REQUEST_TIMEOUT = 60
|
||||
|
||||
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
||||
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
|
||||
|
||||
@@ -2,6 +2,81 @@
|
||||
from ..utils import DummyObject, requires_backends
|
||||
|
||||
|
||||
class AdaptiveProjectedGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ClassifierFreeGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ClassifierFreeZeroStarGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PerturbedAttentionGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class SkipLayerGuidance(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FasterCacheConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -17,6 +92,21 @@ class FasterCacheConfig(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FirstBlockCacheConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HookRegistry(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -32,6 +122,21 @@ class HookRegistry(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LayerSkipConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -51,6 +156,14 @@ def apply_faster_cache(*args, **kwargs):
|
||||
requires_backends(apply_faster_cache, ["torch"])
|
||||
|
||||
|
||||
def apply_first_block_cache(*args, **kwargs):
|
||||
requires_backends(apply_first_block_cache, ["torch"])
|
||||
|
||||
|
||||
def apply_layer_skip(*args, **kwargs):
|
||||
requires_backends(apply_layer_skip, ["torch"])
|
||||
|
||||
|
||||
def apply_pyramid_attention_broadcast(*args, **kwargs):
|
||||
requires_backends(apply_pyramid_attention_broadcast, ["torch"])
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import PIL.Image
|
||||
import PIL.ImageOps
|
||||
import requests
|
||||
|
||||
from .constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from .import_utils import BACKENDS_MAPPING, is_imageio_available
|
||||
|
||||
|
||||
@@ -29,7 +30,7 @@ def load_image(
|
||||
"""
|
||||
if isinstance(image, str):
|
||||
if image.startswith("http://") or image.startswith("https://"):
|
||||
image = PIL.Image.open(requests.get(image, stream=True).raw)
|
||||
image = PIL.Image.open(requests.get(image, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
|
||||
elif os.path.isfile(image):
|
||||
image = PIL.Image.open(image)
|
||||
else:
|
||||
|
||||
@@ -26,6 +26,7 @@ import requests
|
||||
from numpy.linalg import norm
|
||||
from packaging import version
|
||||
|
||||
from .constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
from .import_utils import (
|
||||
BACKENDS_MAPPING,
|
||||
is_accelerate_available,
|
||||
@@ -594,7 +595,7 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
|
||||
# local_path can be passed to correct images of tests
|
||||
return Path(local_path, arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]).as_posix()
|
||||
elif arry.startswith("http://") or arry.startswith("https://"):
|
||||
response = requests.get(arry)
|
||||
response = requests.get(arry, timeout=DIFFUSERS_REQUEST_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
arry = np.load(BytesIO(response.content))
|
||||
elif os.path.isfile(arry):
|
||||
@@ -615,7 +616,7 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
|
||||
|
||||
|
||||
def load_pt(url: str, map_location: str):
|
||||
response = requests.get(url)
|
||||
response = requests.get(url, timeout=DIFFUSERS_REQUEST_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
arry = torch.load(BytesIO(response.content), map_location=map_location)
|
||||
return arry
|
||||
@@ -634,7 +635,7 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
|
||||
"""
|
||||
if isinstance(image, str):
|
||||
if image.startswith("http://") or image.startswith("https://"):
|
||||
image = PIL.Image.open(requests.get(image, stream=True).raw)
|
||||
image = PIL.Image.open(requests.get(image, stream=True, timeout=DIFFUSERS_REQUEST_TIMEOUT).raw)
|
||||
elif os.path.isfile(image):
|
||||
image = PIL.Image.open(image)
|
||||
else:
|
||||
|
||||
@@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool:
|
||||
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
|
||||
|
||||
|
||||
def unwrap_module(module):
|
||||
"""Unwraps a module if it was compiled with torch.compile()"""
|
||||
return module._orig_mod if is_compiled_module(module) else module
|
||||
|
||||
|
||||
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
|
||||
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ from diffusers.utils.testing_utils import (
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import (
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
check_qkv_fusion_matches_attn_procs_length,
|
||||
@@ -44,7 +45,11 @@ enable_full_determinism()
|
||||
|
||||
|
||||
class CogVideoXPipelineFastTests(
|
||||
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = CogVideoXPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
|
||||
@@ -25,6 +25,7 @@ from diffusers.utils.testing_utils import (
|
||||
|
||||
from ..test_pipelines_common import (
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
FluxIPAdapterTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
@@ -34,11 +35,12 @@ from ..test_pipelines_common import (
|
||||
|
||||
|
||||
class FluxPipelineFastTests(
|
||||
unittest.TestCase,
|
||||
PipelineTesterMixin,
|
||||
FluxIPAdapterTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = FluxPipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
|
||||
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
|
||||
|
||||
from ..test_pipelines_common import (
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
to_np,
|
||||
@@ -43,7 +44,11 @@ enable_full_determinism()
|
||||
|
||||
|
||||
class HunyuanVideoPipelineFastTests(
|
||||
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = HunyuanVideoPipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
|
||||
@@ -23,13 +23,13 @@ from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LT
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase):
|
||||
pipeline_class = LTXPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
@@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def get_dummy_components(self):
|
||||
def get_dummy_components(self, num_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = LTXVideoTransformer3DModel(
|
||||
in_channels=8,
|
||||
@@ -59,7 +59,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
num_attention_heads=4,
|
||||
attention_head_dim=8,
|
||||
cross_attention_dim=32,
|
||||
num_layers=1,
|
||||
num_layers=num_layers,
|
||||
caption_channels=32,
|
||||
)
|
||||
|
||||
|
||||
@@ -33,13 +33,15 @@ from diffusers.utils.testing_utils import (
|
||||
)
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
|
||||
from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
|
||||
class MochiPipelineFastTests(
|
||||
PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase
|
||||
):
|
||||
pipeline_class = MochiPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
|
||||
@@ -33,6 +33,7 @@ from diffusers import (
|
||||
)
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
|
||||
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
|
||||
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
|
||||
@@ -2631,7 +2632,7 @@ class FasterCacheTesterMixin:
|
||||
self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
|
||||
pipe = create_pipe()
|
||||
pipe.transformer.enable_cache(self.faster_cache_config)
|
||||
output = run_forward(pipe).flatten().flatten()
|
||||
output = run_forward(pipe).flatten()
|
||||
image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
# Run inference with FasterCache disabled
|
||||
@@ -2738,6 +2739,55 @@ class FasterCacheTesterMixin:
|
||||
self.assertTrue(state.cache is None, "Cache should be reset to None.")
|
||||
|
||||
|
||||
# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
|
||||
# of the box once there is better cache support/implementation
|
||||
class FirstBlockCacheTesterMixin:
|
||||
# threshold is intentionally set higher than usual values since we're testing with random unconverged models
|
||||
# that will not satisfy the expected properties of the denoiser for caching to be effective
|
||||
first_block_cache_config = FirstBlockCacheConfig(threshold=0.8)
|
||||
|
||||
def test_first_block_cache_inference(self, expected_atol: float = 0.1):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
|
||||
def create_pipe():
|
||||
torch.manual_seed(0)
|
||||
num_layers = 2
|
||||
components = self.get_dummy_components(num_layers=num_layers)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
return pipe
|
||||
|
||||
def run_forward(pipe):
|
||||
torch.manual_seed(0)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 4
|
||||
return pipe(**inputs)[0]
|
||||
|
||||
# Run inference without FirstBlockCache
|
||||
pipe = create_pipe()
|
||||
output = run_forward(pipe).flatten()
|
||||
original_image_slice = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
# Run inference with FirstBlockCache enabled
|
||||
pipe = create_pipe()
|
||||
pipe.transformer.enable_cache(self.first_block_cache_config)
|
||||
output = run_forward(pipe).flatten()
|
||||
image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
# Run inference with FirstBlockCache disabled
|
||||
pipe.transformer.disable_cache()
|
||||
output = run_forward(pipe).flatten()
|
||||
image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fbc_enabled, atol=expected_atol
|
||||
), "FirstBlockCache outputs should not differ much."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_fbc_disabled, atol=1e-4
|
||||
), "Outputs from normal inference and after disabling cache should not differ."
|
||||
|
||||
|
||||
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
|
||||
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
|
||||
# reference image.
|
||||
|
||||
@@ -13,9 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import requests
|
||||
from packaging.version import parse
|
||||
|
||||
from ..src.diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
|
||||
|
||||
# GitHub repository details
|
||||
USER = "huggingface"
|
||||
@@ -27,7 +30,11 @@ def fetch_all_branches(user, repo):
|
||||
page = 1 # Start from first page
|
||||
while True:
|
||||
# Make a request to the GitHub API for the branches
|
||||
response = requests.get(f"https://api.github.com/repos/{user}/{repo}/branches", params={"page": page})
|
||||
response = requests.get(
|
||||
f"https://api.github.com/repos/{user}/{repo}/branches",
|
||||
params={"page": page},
|
||||
timeout=DIFFUSERS_REQUEST_TIMEOUT,
|
||||
)
|
||||
|
||||
# Check if the request was successful
|
||||
if response.status_code == 200:
|
||||
|
||||
@@ -17,6 +17,8 @@ import os
|
||||
|
||||
import requests
|
||||
|
||||
from ..src.diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
||||
|
||||
|
||||
# Configuration
|
||||
LIBRARY_NAME = "diffusers"
|
||||
@@ -26,7 +28,7 @@ SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL")
|
||||
|
||||
def check_pypi_for_latest_release(library_name):
|
||||
"""Check PyPI for the latest release of the library."""
|
||||
response = requests.get(f"https://pypi.org/pypi/{library_name}/json")
|
||||
response = requests.get(f"https://pypi.org/pypi/{library_name}/json", timeout=DIFFUSERS_REQUEST_TIMEOUT)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return data["info"]["version"]
|
||||
@@ -38,7 +40,7 @@ def check_pypi_for_latest_release(library_name):
|
||||
def get_github_release_info(github_repo):
|
||||
"""Fetch the latest release info from GitHub."""
|
||||
url = f"https://api.github.com/repos/{github_repo}/releases/latest"
|
||||
response = requests.get(url)
|
||||
response = requests.get(url, timeout=DIFFUSERS_REQUEST_TIMEOUT)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
|
||||
Reference in New Issue
Block a user