mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-08 05:24:20 +08:00)

Compare commits: 22 commits, add-uv-scr ... pyramid-at
Commits:
- 1afc0fc616
- 995b82fb67
- d95d61ae3f
- 3de2c18964
- c52cf422d0
- 18b7d6d9e2
- a5f51bbab3
- 37d23669cf
- 6b1f55ec97
- 9cb4e876bc
- 6816fe15a4
- 6265b65469
- afd0c176d1
- b3547c647e
- 9f6987fdb0
- 955e4f7c35
- ae4abb14a3
- 1c97e04fa5
- d5c738defe
- 373710167a
- 6d3bdb5511
- 67c729d448
@@ -15,7 +15,7 @@
|
||||
|
||||
# CogVideoX
|
||||
|
||||
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://arxiv.org/abs/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang.
|
||||
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://huggingface.co/papers/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
@@ -120,6 +120,45 @@ It is also worth noting that torchao quantization is fully compatible with [torc
|
||||
- [torchao](https://gist.github.com/a-r-r-o-w/4d9732d17412888c885480c6521a9897)
|
||||
- [quanto](https://gist.github.com/a-r-r-o-w/31be62828b00a9292821b85c1017effa)
|
||||
|
||||
### Pyramid Attention Broadcast
|
||||
|
||||
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) by Xuanlei Zhao, Xiaolong Jin, Kai Wang, and Yang You.
|
||||
|
||||
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. Attention states change only slightly between successive steps; the difference is largest in the spatial attention blocks, smaller in the temporal attention blocks, and smallest in the cross attention blocks. Cross attention computations can therefore be skipped most often, followed by the temporal and then the spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
|
||||
|
||||
Enable PAB with [`~PyramidAttentionBroadcastMixin.enable_pyramid_attention_broadcast`] on any pipeline that keeps track of the current inference timestep.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.float16)
|
||||
pipe.to("cuda")
|
||||
|
||||
pipe.enable_pyramid_attention_broadcast(
|
||||
spatial_attn_skip_range=2,
|
||||
spatial_attn_timestep_range=[100, 850],
|
||||
)
|
||||
|
||||
prompt = (
|
||||
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
|
||||
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
|
||||
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
|
||||
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
|
||||
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
|
||||
"atmosphere of this unique musical performance."
|
||||
)
|
||||
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
|
||||
export_to_video(video, "output.mp4", fps=8)
|
||||
```
|
||||
|
||||
| model | model_memory (GB) | normal_memory (GB) | pab_memory (GB) | normal_time (s) | pab_time (s) | speedup |
|:----------:|:--------------:|:---------------:|:------------:|:-------------:|:----------:|:---------:|
| Cog-2b T2V | 12.55 | 35.342 | 35.342 | 86.915 | 63.914 | 1.359 |
| Cog-5b T2V | 19.66 | 40.945 | 40.945 | 246.152 | 168.59 | 1.460 |
| Cog-5b I2V | 19.764 | 42.74 | 42.74 | 246.867 | 170.111 | 1.451 |
|
||||
|
||||
## CogVideoXPipeline
|
||||
|
||||
[[autodoc]] CogVideoXPipeline
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||

|
||||
|
||||
[Latte: Latent Diffusion Transformer for Video Generation](https://arxiv.org/abs/2401.03048) from Monash University, Shanghai AI Lab, Nanjing University, and Nanyang Technological University.
|
||||
[Latte: Latent Diffusion Transformer for Video Generation](https://huggingface.co/papers/2401.03048) from Monash University, Shanghai AI Lab, Nanjing University, and Nanyang Technological University.
|
||||
|
||||
The abstract from the paper is:
|
||||
|
||||
@@ -70,6 +70,37 @@ Without torch.compile(): Average inference time: 16.246 seconds.
|
||||
With torch.compile(): Average inference time: 14.573 seconds.
|
||||
```
|
||||
|
||||
### Pyramid Attention Broadcast
|
||||
|
||||
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) by Xuanlei Zhao, Xiaolong Jin, Kai Wang, and Yang You.
|
||||
|
||||
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and re-using cached attention states. This works because the attention states do not differ much numerically between successive steps. The difference is most prominent in the spatial attention blocks, smaller in the temporal attention blocks, and smallest in the cross attention blocks. Therefore, many cross attention computations can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like sequence parallelism and classifier-free guidance parallelism, the authors achieve near real-time video generation.
|
||||
|
||||
PAB can be enabled on any pipeline by deriving from the [`PyramidAttentionBroadcastMixin`] and keeping track of the current inference timestep in the pipeline. A minimal example of using PAB with Latte:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import LattePipeline
|
||||
from diffusers.utils import export_to_gif
|
||||
|
||||
pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16)
|
||||
|
||||
pipe.enable_pyramid_attention_broadcast(
|
||||
spatial_attn_skip_range=2,
|
||||
cross_attn_skip_range=6,
|
||||
spatial_attn_timestep_range=[100, 800],
|
||||
cross_attn_timestep_range=[100, 800],
|
||||
)
|
||||
|
||||
prompt = "A small cactus with a happy face in the Sahara desert."
|
||||
videos = pipe(prompt).frames[0]
|
||||
export_to_gif(videos, "latte.gif")
|
||||
```
|
||||
|
||||
| model | model_memory (GB) | normal_memory (GB) | pab_memory (GB) | normal_time (s) | pab_time (s) | speedup |
|:----------:|:--------------:|:---------------:|:------------:|:-------------:|:----------:|:---------:|
| Latte | 11.007 | 25.594 | 25.594 | 28.026 | 24.073 | 1.164 |
|
||||
|
||||
## LattePipeline
|
||||
|
||||
[[autodoc]] LattePipeline
|
||||
|
||||
src/diffusers/models/hooks.py (new file, 278 lines)
@@ -0,0 +1,278 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py
|
||||
class ModelHook:
|
||||
r"""
|
||||
A hook that contains callbacks to be executed just before and after the forward method of a model. The difference
from existing PyTorch hooks is that these hooks also receive the kwargs.
|
||||
"""
|
||||
|
||||
def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
r"""
|
||||
Hook that is executed when a model is initialized.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module attached to this hook.
|
||||
"""
|
||||
return module
|
||||
|
||||
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
|
||||
r"""
|
||||
Hook that is executed just before the forward method of the model.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module whose forward pass will be executed just after this event.
|
||||
args (`Tuple[Any]`):
|
||||
The positional arguments passed to the module.
|
||||
kwargs (`Dict[str, Any]`):
|
||||
The keyword arguments passed to the module.
|
||||
|
||||
Returns:
|
||||
`Tuple[Tuple[Any], Dict[str, Any]]`:
|
||||
A tuple with the treated `args` and `kwargs`.
|
||||
"""
|
||||
return args, kwargs
|
||||
|
||||
def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
|
||||
r"""
|
||||
Hook that is executed just after the forward method of the model.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module whose forward pass has been executed just before this event.
|
||||
output (`Any`):
|
||||
The output of the module.
|
||||
|
||||
Returns:
|
||||
`Any`: The processed `output`.
|
||||
"""
|
||||
return output
|
||||
|
||||
def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
r"""
|
||||
Hook that is executed when the hook is detached from a module.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module detached from this hook.
|
||||
"""
|
||||
return module
|
||||
|
||||
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
return module
|
||||
|
||||
|
||||
class SequentialHook(ModelHook):
|
||||
r"""A hook that can contain several hooks and iterates through them at each event."""
|
||||
|
||||
def __init__(self, *hooks):
|
||||
self.hooks = hooks
|
||||
|
||||
def init_hook(self, module):
|
||||
for hook in self.hooks:
|
||||
module = hook.init_hook(module)
|
||||
return module
|
||||
|
||||
def pre_forward(self, module, *args, **kwargs):
|
||||
for hook in self.hooks:
|
||||
args, kwargs = hook.pre_forward(module, *args, **kwargs)
|
||||
return args, kwargs
|
||||
|
||||
def post_forward(self, module, output):
|
||||
for hook in self.hooks:
|
||||
output = hook.post_forward(module, output)
|
||||
return output
|
||||
|
||||
def detach_hook(self, module):
|
||||
for hook in self.hooks:
|
||||
module = hook.detach_hook(module)
|
||||
return module
|
||||
|
||||
def reset_state(self, module):
|
||||
for hook in self.hooks:
|
||||
module = hook.reset_state(module)
|
||||
return module
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastHook(ModelHook):
|
||||
def __init__(
|
||||
self,
|
||||
skip_callback: Callable[[torch.nn.Module], bool],
|
||||
# skip_range: int,
|
||||
# timestep_range: Tuple[int, int],
|
||||
# timestep_callback: Callable[[], Union[torch.LongTensor, int]],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# self.skip_range = skip_range
|
||||
# self.timestep_range = timestep_range
|
||||
# self.timestep_callback = timestep_callback
|
||||
self.skip_callback = skip_callback
|
||||
|
||||
self.cache = None
|
||||
self._iteration = 0
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
|
||||
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
|
||||
|
||||
# current_timestep = self.timestep_callback()
|
||||
# is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1]
|
||||
# should_compute_attention = self._iteration % self.skip_range == 0
|
||||
|
||||
# if not is_within_timestep_range or should_compute_attention:
|
||||
# output = module._old_forward(*args, **kwargs)
|
||||
# else:
|
||||
# output = self.attention_cache
|
||||
|
||||
if self.cache is not None and self.skip_callback(module):
|
||||
output = self.cache
|
||||
else:
|
||||
output = module._old_forward(*args, **kwargs)
|
||||
|
||||
return module._diffusers_hook.post_forward(module, output)
|
||||
|
||||
def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
|
||||
self.cache = output
|
||||
return output
|
||||
|
||||
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
self.cache = None
|
||||
self._iteration = 0
|
||||
return module
|
||||
|
||||
|
||||
class LayerSkipHook(ModelHook):
|
||||
def __init__(self, skip_: Callable[[torch.nn.Module], bool]) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.skip_callback = skip_
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
|
||||
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
|
||||
|
||||
if self.skip_callback(module):
|
||||
# We want to skip this layer, so we have to return the input of the current layer
|
||||
# as output of the next layer. But at this point, we don't have information about
|
||||
# the arguments required by next layer. Even if we did, order matters unless we
|
||||
# always pass kwargs. But that is not the case usually with hidden_states, encoder_hidden_states,
|
||||
# temb, etc. TODO(aryan): implement correctly later
|
||||
output = None
|
||||
else:
|
||||
output = module._old_forward(*args, **kwargs)
|
||||
|
||||
return module._diffusers_hook.post_forward(module, output)
|
||||
|
||||
|
||||
def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False):
|
||||
r"""
|
||||
Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
|
||||
this behavior and restore the original `forward` method, use `remove_hook_from_module`.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
By default, if the module already contains a hook, this will replace it with the new hook. To chain two hooks
together, pass `append=True`; the existing and new hooks are then combined into an instance of the `SequentialHook` class.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to attach a hook to.
|
||||
hook (`ModelHook`):
|
||||
The hook to attach.
|
||||
append (`bool`, *optional*, defaults to `False`):
|
||||
Whether the hook should be chained with an existing one (if module already contains a hook) or not.
|
||||
|
||||
Returns:
|
||||
`torch.nn.Module`:
|
||||
The same module, with the hook attached (the module is modified in place, so the result can be discarded).
|
||||
"""
|
||||
original_hook = hook
|
||||
|
||||
if append and getattr(module, "_diffusers_hook", None) is not None:
|
||||
old_hook = module._diffusers_hook
|
||||
remove_hook_from_module(module)
|
||||
hook = SequentialHook(old_hook, hook)
|
||||
|
||||
if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"):
|
||||
# If we already put some hook on this module, we replace it with the new one.
|
||||
old_forward = module._old_forward
|
||||
else:
|
||||
old_forward = module.forward
|
||||
module._old_forward = old_forward
|
||||
|
||||
module = hook.init_hook(module)
|
||||
module._diffusers_hook = hook
|
||||
|
||||
if hasattr(original_hook, "new_forward"):
|
||||
new_forward = original_hook.new_forward
|
||||
else:
|
||||
|
||||
def new_forward(module, *args, **kwargs):
|
||||
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
|
||||
output = module._old_forward(*args, **kwargs)
|
||||
return module._diffusers_hook.post_forward(module, output)
|
||||
|
||||
# Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
|
||||
# Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
|
||||
if "GraphModuleImpl" in str(type(module)):
|
||||
module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
|
||||
else:
|
||||
module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module:
|
||||
"""
|
||||
Removes any hook attached to a module via `add_hook_to_module`.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to remove the hook from.
|
||||
recurse (`bool`, defaults to `False`):
|
||||
Whether to remove the hooks recursively
|
||||
|
||||
Returns:
|
||||
`torch.nn.Module`:
|
||||
The same module, with the hook detached (the module is modified in place, so the result can be discarded).
|
||||
"""
|
||||
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook.detach_hook(module)
|
||||
delattr(module, "_diffusers_hook")
|
||||
|
||||
if hasattr(module, "_old_forward"):
|
||||
# Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
|
||||
# Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
|
||||
if "GraphModuleImpl" in str(type(module)):
|
||||
module.__class__.forward = module._old_forward
|
||||
else:
|
||||
module.forward = module._old_forward
|
||||
delattr(module, "_old_forward")
|
||||
|
||||
if recurse:
|
||||
for child in module.children():
|
||||
remove_hook_from_module(child, recurse)
|
||||
|
||||
return module
|
||||
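To show how the pieces above fit together, here is a minimal sketch (not part of the diff) that attaches a custom `ModelHook` to a module with `add_hook_to_module` and removes it again. The `ShapeLoggingHook` class is a hypothetical example built only on the API defined in this file:

```python
import torch

from diffusers.models.hooks import ModelHook, add_hook_to_module, remove_hook_from_module


class ShapeLoggingHook(ModelHook):
    # Hypothetical hook: report tensor shapes right before and after the wrapped forward call.
    def pre_forward(self, module, *args, **kwargs):
        print("pre_forward:", [tuple(a.shape) for a in args if torch.is_tensor(a)])
        return args, kwargs

    def post_forward(self, module, output):
        print("post_forward:", tuple(output.shape))
        return output


layer = torch.nn.Linear(8, 4)
add_hook_to_module(layer, ShapeLoggingHook())
_ = layer(torch.randn(2, 8))    # runs pre_forward -> original forward -> post_forward
remove_hook_from_module(layer)  # restores the original `forward`
```

Because `ShapeLoggingHook` does not define `new_forward`, the default wrapper installed by `add_hook_to_module` is used: it calls `pre_forward`, then the original forward, then `post_forward`.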
@@ -11,7 +11,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -19,6 +20,7 @@ from torch import nn
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..attention_processor import AttentionProcessor
|
||||
from ..embeddings import PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -165,6 +167,66 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor()
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
||||
@@ -779,6 +779,7 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
negative_prompt_attention_mask,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
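# `_current_timestep` is tracked so per-module hooks (e.g. Pyramid Attention Broadcast) can query the active denoising timestep.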
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default height and width to transformer
|
||||
@@ -856,6 +857,7 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -894,6 +896,8 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
video = self.decode_latents(latents)
|
||||
|
||||
@@ -622,6 +622,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default call parameters
|
||||
@@ -700,6 +701,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -755,6 +757,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
# Discard any padding frames that were added for CogVideoX 1.5
|
||||
latents = latents[:, additional_frames:]
|
||||
|
||||
@@ -675,6 +675,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default call parameters
|
||||
@@ -761,6 +762,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -810,6 +812,8 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
video = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
|
||||
@@ -722,6 +722,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
@@ -809,6 +810,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -868,6 +870,8 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
# Discard any padding frames that were added for CogVideoX 1.5
|
||||
latents = latents[:, additional_frames:]
|
||||
|
||||
@@ -701,6 +701,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default call parameters
|
||||
@@ -787,6 +788,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -842,6 +844,8 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
video = self.decode_latents(latents)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
|
||||
@@ -623,7 +623,7 @@ class LattePipeline(DiffusionPipeline):
|
||||
clean_caption: bool = True,
|
||||
mask_feature: bool = True,
|
||||
enable_temporal_attentions: bool = True,
|
||||
decode_chunk_size: Optional[int] = None,
|
||||
decode_chunk_size: int = 14,
|
||||
) -> Union[LattePipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -719,6 +719,7 @@ class LattePipeline(DiffusionPipeline):
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Default height and width to transformer
|
||||
@@ -780,6 +781,7 @@ class LattePipeline(DiffusionPipeline):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
@@ -836,8 +838,10 @@ class LattePipeline(DiffusionPipeline):
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if not output_type == "latents":
|
||||
video = self.decode_latents(latents, video_length, decode_chunk_size=14)
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size)
|
||||
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
|
||||
@@ -1088,6 +1088,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
|
||||
functions correctly when applying enable_model_cpu_offload.
|
||||
"""
|
||||
|
||||
if hasattr(self, "_diffusers_hook"):
|
||||
self._diffusers_hook.reset_state()
|
||||
|
||||
if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
|
||||
# `enable_model_cpu_offload` has not been called, so silently do nothing
|
||||
return
|
||||
|
||||
src/diffusers/pipelines/pyramid_attention_broadcast_utils.py (new file, 315 lines)
@@ -0,0 +1,315 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from ..models.attention_processor import Attention
|
||||
from ..models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module, remove_hook_from_module
|
||||
from ..utils import logging
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
_ATTENTION_CLASSES = (Attention,)
|
||||
|
||||
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
|
||||
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PyramidAttentionBroadcastConfig:
|
||||
spatial_attention_block_skip_range: Optional[int] = None
|
||||
temporal_attention_block_skip_range: Optional[int] = None
|
||||
cross_attention_block_skip_range: Optional[int] = None
|
||||
|
||||
spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
|
||||
|
||||
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
|
||||
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastState:
|
||||
def __init__(self) -> None:
|
||||
self.iteration = 0
|
||||
|
||||
def reset_state(self):
|
||||
self.iteration = 0
|
||||
|
||||
|
||||
def apply_pyramid_attention_broadcast(
|
||||
pipeline: DiffusionPipeline,
|
||||
config: Optional[PyramidAttentionBroadcastConfig] = None,
|
||||
denoiser: Optional[nn.Module] = None,
|
||||
):
|
||||
if config is None:
|
||||
config = PyramidAttentionBroadcastConfig()
|
||||
|
||||
if (
|
||||
config.spatial_attention_block_skip_range is None
|
||||
and config.temporal_attention_block_skip_range is None
|
||||
and config.cross_attention_block_skip_range is None
|
||||
):
|
||||
logger.warning(
|
||||
"Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` "
|
||||
"or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. "
|
||||
"To avoid this warning, please set one of the above parameters."
|
||||
)
|
||||
config.spatial_attention_block_skip_range = 2
|
||||
|
||||
if denoiser is None:
|
||||
denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
|
||||
|
||||
for name, module in denoiser.named_modules():
|
||||
if not isinstance(module, _ATTENTION_CLASSES):
|
||||
continue
|
||||
if isinstance(module, Attention):
|
||||
_apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config)
|
||||
|
||||
|
||||
def apply_pyramid_attention_broadcast_on_module(
|
||||
module: Attention,
|
||||
block_skip_range: int,
|
||||
timestep_skip_range: Tuple[int, int],
|
||||
current_timestep_callback: Callable[[], int],
|
||||
):
|
||||
module._pyramid_attention_broadcast_state = PyramidAttentionBroadcastState()
|
||||
min_timestep, max_timestep = timestep_skip_range
|
||||
|
||||
def skip_callback(attention_module: nn.Module) -> bool:
|
||||
pab_state: PyramidAttentionBroadcastState = attention_module._pyramid_attention_broadcast_state
|
||||
current_timestep = current_timestep_callback()
|
||||
is_within_timestep_range = min_timestep < current_timestep < max_timestep
|
||||
|
||||
if is_within_timestep_range:
|
||||
# As soon as the current timestep is within the timestep range, we start skipping attention computation.
|
||||
# The following inference steps will compute the attention every `block_skip_range` steps.
|
||||
should_compute_attention = pab_state.iteration > 0 and pab_state.iteration % block_skip_range == 0
|
||||
pab_state.iteration += 1
|
||||
|
||||
return not should_compute_attention
|
||||
|
||||
# We are not yet in the phase of inference where skipping attention is possible without a significant quality
# loss, as described in the paper. So the attention computation cannot be skipped.
|
||||
return False
|
||||
|
||||
hook = PyramidAttentionBroadcastHook(skip_callback=skip_callback)
|
||||
add_hook_to_module(module, hook, append=True)
|
||||
|
||||
|
||||
def _apply_pyramid_attention_broadcast_on_attention_class(
|
||||
pipeline: DiffusionPipeline, name: str, module: Attention, config: PyramidAttentionBroadcastConfig
|
||||
):
|
||||
# Similar check as PEFT to determine if a string layer name matches a module name
|
||||
is_spatial_self_attention = (
|
||||
any(
|
||||
f"{identifier}." in name or identifier == name for identifier in config.spatial_attention_block_identifiers
|
||||
)
|
||||
and config.spatial_attention_block_skip_range is not None
|
||||
and not module.is_cross_attention
|
||||
)
|
||||
is_temporal_self_attention = (
|
||||
any(
|
||||
f"{identifier}." in name or identifier == name
|
||||
for identifier in config.temporal_attention_block_identifiers
|
||||
)
|
||||
and config.temporal_attention_block_skip_range is not None
|
||||
and not module.is_cross_attention
|
||||
)
|
||||
is_cross_attention = (
|
||||
any(f"{identifier}." in name or identifier == name for identifier in config.cross_attention_block_identifiers)
|
||||
and config.cross_attention_block_skip_range is not None
|
||||
and module.is_cross_attention
|
||||
)
|
||||
|
||||
block_skip_range, timestep_skip_range = None, None
|
||||
if is_spatial_self_attention:
|
||||
block_skip_range = config.spatial_attention_block_skip_range
|
||||
timestep_skip_range = config.spatial_attention_timestep_skip_range
|
||||
elif is_temporal_self_attention:
|
||||
block_skip_range = config.temporal_attention_block_skip_range
|
||||
timestep_skip_range = config.temporal_attention_timestep_skip_range
|
||||
elif is_cross_attention:
|
||||
block_skip_range = config.cross_attention_block_skip_range
|
||||
timestep_skip_range = config.cross_attention_timestep_skip_range
|
||||
|
||||
if block_skip_range is None or timestep_skip_range is None:
|
||||
logger.warning(f"Unable to apply Pyramid Attention Broadcast to the selected layer: {name}.")
|
||||
return
|
||||
|
||||
def current_timestep_callback():
|
||||
return pipeline._current_timestep
|
||||
|
||||
apply_pyramid_attention_broadcast_on_module(
|
||||
module, block_skip_range, timestep_skip_range, current_timestep_callback
|
||||
)
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastMixin:
|
||||
r"""Mixin class for [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588)."""
|
||||
|
||||
def _enable_pyramid_attention_broadcast(self) -> None:
|
||||
denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet
|
||||
|
||||
for name, module in denoiser.named_modules():
|
||||
if isinstance(module, Attention):
|
||||
is_spatial_attention = (
|
||||
any(x in name for x in self._pab_spatial_attn_layer_identifiers)
|
||||
and self._pab_spatial_attn_skip_range is not None
|
||||
and not module.is_cross_attention
|
||||
)
|
||||
is_temporal_attention = (
|
||||
any(x in name for x in self._pab_temporal_attn_layer_identifiers)
|
||||
and self._pab_temporal_attn_skip_range is not None
|
||||
and not module.is_cross_attention
|
||||
)
|
||||
is_cross_attention = (
|
||||
any(x in name for x in self._pab_cross_attn_layer_identifiers)
|
||||
and self._pab_cross_attn_skip_range is not None
|
||||
and module.is_cross_attention
|
||||
)
|
||||
|
||||
if is_spatial_attention:
|
||||
skip_range = self._pab_spatial_attn_skip_range
|
||||
timestep_range = self._pab_spatial_attn_timestep_range
|
||||
if is_temporal_attention:
|
||||
skip_range = self._pab_temporal_attn_skip_range
|
||||
timestep_range = self._pab_temporal_attn_timestep_range
|
||||
if is_cross_attention:
|
||||
skip_range = self._pab_cross_attn_skip_range
|
||||
timestep_range = self._pab_cross_attn_timestep_range
|
||||
|
||||
if skip_range is None:
|
||||
continue
|
||||
|
||||
# logger.debug(f"Enabling Pyramid Attention Broadcast in layer: {name}")
|
||||
print(f"Enabling Pyramid Attention Broadcast in layer: {name}")
|
||||
|
||||
add_hook_to_module(
|
||||
module,
|
||||
PyramidAttentionBroadcastHook(
|
||||
skip_range=skip_range,
|
||||
timestep_range=timestep_range,
|
||||
timestep_callback=self._pyramid_attention_broadcast_timestep_callback,
|
||||
),
|
||||
append=True,
|
||||
)
|
||||
|
||||
def _disable_pyramid_attention_broadcast(self) -> None:
|
||||
denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet
|
||||
for name, module in denoiser.named_modules():
|
||||
logger.debug(f"Disabling Pyramid Attention Broadcast in layer: {name}")
|
||||
remove_hook_from_module(module)
|
||||
|
||||
def _pyramid_attention_broadcast_timestep_callback(self):
|
||||
return self._current_timestep
|
||||
|
||||
def enable_pyramid_attention_broadcast(
|
||||
self,
|
||||
spatial_attn_skip_range: Optional[int] = None,
|
||||
spatial_attn_timestep_range: Tuple[int, int] = (100, 800),
|
||||
temporal_attn_skip_range: Optional[int] = None,
|
||||
cross_attn_skip_range: Optional[int] = None,
|
||||
temporal_attn_timestep_range: Tuple[int, int] = (100, 800),
|
||||
cross_attn_timestep_range: Tuple[int, int] = (100, 800),
|
||||
spatial_attn_layer_identifiers: List[str] = ["blocks", "transformer_blocks"],
|
||||
temporal_attn_layer_identifiers: List[str] = ["temporal_transformer_blocks"],
|
||||
cross_attn_layer_identifiers: List[str] = ["blocks", "transformer_blocks"],
|
||||
) -> None:
|
||||
r"""
|
||||
Enable pyramid attention broadcast to speed up inference by re-using attention states and skipping computation
|
||||
systematically as described in the paper: [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588).
|
||||
|
||||
Args:
|
||||
spatial_attn_skip_range (`int`, *optional*):
|
||||
The interval at which spatial attention is recomputed. If set to `N`, spatial attention is computed on every
N-th inference step within the active timestep range and the cached result is re-used for the `N - 1` steps in
between. Different models tolerate different amounts of re-use depending on how much the attention states change
between successive steps, so this parameter should be tuned per model. Setting this value to `2` is recommended
for the models PAB has been experimented with.
|
||||
temporal_attn_skip_range (`int`, *optional*):
|
||||
The interval at which temporal attention is recomputed. If set to `N`, temporal attention is computed on every
N-th inference step within the active timestep range and the cached result is re-used for the `N - 1` steps in
between. Different models tolerate different amounts of re-use depending on how much the attention states change
between successive steps, so this parameter should be tuned per model. Setting this value to `4` is recommended
for the models PAB has been experimented with.
|
||||
cross_attn_skip_range (`int`, *optional*):
|
||||
The interval at which cross attention is recomputed. If set to `N`, cross attention is computed on every
N-th inference step within the active timestep range and the cached result is re-used for the `N - 1` steps in
between. Different models tolerate different amounts of re-use depending on how much the attention states change
between successive steps, so this parameter should be tuned per model. Setting this value to `6` is recommended
for the models PAB has been experimented with.
|
||||
spatial_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`):
|
||||
The timestep range between which PAB will remain activated in spatial attention blocks. While
|
||||
activated, PAB will re-use attention computations between inference steps.
|
||||
temporal_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`):
|
||||
The timestep range between which PAB will remain activated in temporal attention blocks. While
|
||||
activated, PAB will re-use attention computations between inference steps.
|
||||
cross_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`):
|
||||
The timestep range between which PAB will remain activated in cross attention blocks. While activated,
|
||||
PAB will re-use attention computations between inference steps.
|
||||
"""
|
||||
|
||||
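# Example: with `spatial_attn_skip_range=2` and `spatial_attn_timestep_range=(100, 800)`, spatial attention is
# recomputed on every other inference step while 100 < timestep < 800 and the cached output is re-used in between;
# outside that range attention is computed on every step.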
if spatial_attn_timestep_range[0] > spatial_attn_timestep_range[1]:
|
||||
raise ValueError(
|
||||
"Expected `spatial_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied."
|
||||
)
|
||||
if temporal_attn_timestep_range[0] > temporal_attn_timestep_range[1]:
|
||||
raise ValueError(
|
||||
"Expected `temporal_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied."
|
||||
)
|
||||
if cross_attn_timestep_range[0] > cross_attn_timestep_range[1]:
|
||||
raise ValueError(
|
||||
"Expected `cross_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied."
|
||||
)
|
||||
|
||||
self._pab_spatial_attn_skip_range = spatial_attn_skip_range
|
||||
self._pab_temporal_attn_skip_range = temporal_attn_skip_range
|
||||
self._pab_cross_attn_skip_range = cross_attn_skip_range
|
||||
self._pab_spatial_attn_timestep_range = spatial_attn_timestep_range
|
||||
self._pab_temporal_attn_timestep_range = temporal_attn_timestep_range
|
||||
self._pab_cross_attn_timestep_range = cross_attn_timestep_range
|
||||
self._pab_spatial_attn_layer_identifiers = spatial_attn_layer_identifiers
|
||||
self._pab_temporal_attn_layer_identifiers = temporal_attn_layer_identifiers
|
||||
self._pab_cross_attn_layer_identifiers = cross_attn_layer_identifiers
|
||||
|
||||
self._pab_enabled = spatial_attn_skip_range or temporal_attn_skip_range or cross_attn_skip_range
|
||||
|
||||
self._enable_pyramid_attention_broadcast()
|
||||
|
||||
def disable_pyramid_attention_broadcast(self) -> None:
|
||||
r"""Disables the pyramid attention broadcast sampling mechanism."""
|
||||
|
||||
self._pab_spatial_attn_skip_range = None
|
||||
self._pab_temporal_attn_skip_range = None
|
||||
self._pab_cross_attn_skip_range = None
|
||||
self._pab_spatial_attn_timestep_range = None
|
||||
self._pab_temporal_attn_timestep_range = None
|
||||
self._pab_cross_attn_timestep_range = None
|
||||
self._pab_spatial_attn_layer_identifiers = None
|
||||
self._pab_temporal_attn_layer_identifiers = None
|
||||
self._pab_cross_attn_layer_identifiers = None
|
||||
self._pab_enabled = False
|
||||
|
||||
@property
|
||||
def pyramid_attention_broadcast_enabled(self):
|
||||
return hasattr(self, "_pab_enabled") and self._pab_enabled
|
||||
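Beyond the mixin, this file also adds a functional entry point. A minimal sketch of using it, assuming the CogVideoX checkpoint from the docs above and a pipeline that tracks `_current_timestep` as added in the pipeline diffs:

```python
import torch

from diffusers import CogVideoXPipeline
from diffusers.pipelines.pyramid_attention_broadcast_utils import (
    PyramidAttentionBroadcastConfig,
    apply_pyramid_attention_broadcast,
)

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.float16)
pipe.to("cuda")

# Recompute spatial self-attention every 2nd step while 100 < timestep < 800; re-use the cache otherwise.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
)
apply_pyramid_attention_broadcast(pipe, config)

video = pipe(prompt="A panda playing a guitar in a bamboo forest", num_inference_steps=50).frames[0]
```

The config-driven path attaches a `PyramidAttentionBroadcastHook` to every matching `Attention` module in the denoiser, which is the same mechanism the mixin uses under the hood.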
@@ -21,6 +21,7 @@ import torch
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler
|
||||
from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
numpy_cosine_similarity_distance,
|
||||
@@ -59,7 +60,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
)
|
||||
test_xformers_attention = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
def get_dummy_components(self, num_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = CogVideoXTransformer3DModel(
|
||||
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings
|
||||
@@ -71,7 +72,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
out_channels=4,
|
||||
time_embed_dim=2,
|
||||
text_embed_dim=32, # Must match with tiny-random-t5
|
||||
num_layers=1,
|
||||
num_layers=num_layers,
|
||||
sample_width=2, # latent width: 2 -> final width: 16
|
||||
sample_height=2, # latent height: 2 -> final height: 16
|
||||
sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9
|
||||
@@ -319,6 +320,54 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
def test_pyramid_attention_broadcast(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
num_layers = 4
|
||||
components = self.get_dummy_components(num_layers=num_layers)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 4
|
||||
frames = pipe(**inputs).frames # [B, F, C, H, W]
|
||||
original_image_slice = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800))
|
||||
assert pipe.pyramid_attention_broadcast_enabled
|
||||
|
||||
num_pab_processors = sum(
|
||||
[
|
||||
isinstance(processor, PyramidAttentionBroadcastAttentionProcessorWrapper)
|
||||
for processor in pipe.transformer.attn_processors.values()
|
||||
]
|
||||
)
|
||||
assert num_pab_processors == num_layers
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 4
|
||||
frames = pipe(**inputs).frames
|
||||
image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
pipe.disable_pyramid_attention_broadcast()
|
||||
assert not pipe.pyramid_attention_broadcast_enabled
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
frames = pipe(**inputs).frames
|
||||
image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
# We need to use higher tolerance because we are using a random model. With a converged/trained
|
||||
# model, the tolerance can be lower.
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_pab_enabled, atol=0.2
|
||||
), "PAB outputs should not differ much in specified timestep range."
|
||||
assert np.allclose(
|
||||
image_slice_pab_enabled, image_slice_pab_disabled, atol=0.2
|
||||
), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_pab_disabled, atol=0.2
|
||||
), "Original outputs should match when PAB is disabled."
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -22,6 +22,7 @@ from PIL import Image
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler
|
||||
from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
@@ -61,7 +62,7 @@ class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
|
||||
)
|
||||
test_xformers_attention = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
def get_dummy_components(self, num_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = CogVideoXTransformer3DModel(
|
||||
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings
|
||||
@@ -76,7 +77,7 @@ class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
|
||||
out_channels=4,
|
||||
time_embed_dim=2,
|
||||
text_embed_dim=32, # Must match with tiny-random-t5
|
||||
num_layers=1,
|
||||
num_layers=num_layers,
|
||||
sample_width=2, # latent width: 2 -> final width: 16
|
||||
sample_height=2, # latent height: 2 -> final height: 16
|
||||
sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9
|
||||
@@ -342,6 +343,54 @@ class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
def test_pyramid_attention_broadcast(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
num_layers = 4
|
||||
components = self.get_dummy_components(num_layers=num_layers)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 4
|
||||
frames = pipe(**inputs).frames # [B, F, C, H, W]
|
||||
original_image_slice = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800))
|
||||
assert pipe.pyramid_attention_broadcast_enabled
|
||||
|
||||
num_pab_processors = sum(
|
||||
[
|
||||
isinstance(processor, PyramidAttentionBroadcastAttentionProcessorWrapper)
|
||||
for processor in pipe.transformer.attn_processors.values()
|
||||
]
|
||||
)
|
||||
assert num_pab_processors == num_layers
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 4
|
||||
frames = pipe(**inputs).frames
|
||||
image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
pipe.disable_pyramid_attention_broadcast()
|
||||
assert not pipe.pyramid_attention_broadcast_enabled
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
frames = pipe(**inputs).frames
|
||||
image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
# We need to use higher tolerance because we are using a random model. With a converged/trained
|
||||
# model, the tolerance can be lower.
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_pab_enabled, atol=0.2
|
||||
), "PAB outputs should not differ much in specified timestep range."
|
||||
assert np.allclose(
|
||||
image_slice_pab_enabled, image_slice_pab_disabled, atol=0.2
|
||||
), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_pab_disabled, atol=0.2
|
||||
), "Original outputs should match when PAB is disabled."
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -21,6 +21,7 @@ from PIL import Image
|
||||
from transformers import AutoTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXVideoToVideoPipeline, DDIMScheduler
|
||||
from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
@@ -53,7 +54,7 @@ class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
|
||||
)
|
||||
test_xformers_attention = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
def get_dummy_components(self, num_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = CogVideoXTransformer3DModel(
|
||||
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings
|
||||
@@ -65,7 +66,7 @@ class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
|
||||
out_channels=4,
|
||||
time_embed_dim=2,
|
||||
text_embed_dim=32, # Must match with tiny-random-t5
|
||||
num_layers=1,
|
||||
num_layers=num_layers,
|
||||
sample_width=2, # latent width: 2 -> final width: 16
|
||||
sample_height=2, # latent height: 2 -> final height: 16
|
||||
sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9
|
||||
@@ -323,3 +324,51 @@ class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
|
||||
), "Original outputs should match when fused QKV projections are disabled."
|
||||
|
||||
def test_pyramid_attention_broadcast(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
num_layers = 4
|
||||
components = self.get_dummy_components(num_layers=num_layers)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 4
|
||||
frames = pipe(**inputs).frames # [B, F, C, H, W]
|
||||
original_image_slice = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800))
|
||||
assert pipe.pyramid_attention_broadcast_enabled
|
||||
|
||||
num_pab_processors = sum(
|
||||
[
|
||||
isinstance(processor, PyramidAttentionBroadcastAttentionProcessorWrapper)
|
||||
for processor in pipe.transformer.attn_processors.values()
|
||||
]
|
||||
)
|
||||
assert num_pab_processors == num_layers
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 4
|
||||
frames = pipe(**inputs).frames
|
||||
image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
pipe.disable_pyramid_attention_broadcast()
|
||||
assert not pipe.pyramid_attention_broadcast_enabled
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
frames = pipe(**inputs).frames
|
||||
image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
# We need to use higher tolerance because we are using a random model. With a converged/trained
|
||||
# model, the tolerance can be lower.
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_pab_enabled, atol=0.2
|
||||
), "PAB outputs should not differ much in specified timestep range."
|
||||
assert np.allclose(
|
||||
image_slice_pab_enabled, image_slice_pab_disabled, atol=0.2
|
||||
), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_pab_disabled, atol=0.2
|
||||
), "Original outputs should match when PAB is disabled."
|
||||
|
||||
@@ -53,11 +53,11 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params
|
||||
|
||||
def get_dummy_components(self):
|
||||
def get_dummy_components(self, num_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = LatteTransformer3DModel(
|
||||
sample_size=8,
|
||||
num_layers=1,
|
||||
num_layers=num_layers,
|
||||
patch_size=2,
|
||||
attention_head_dim=8,
|
||||
num_attention_heads=3,
|
||||
@@ -264,6 +264,47 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def test_xformers_attention_forwardGenerator_pass(self):
|
||||
super()._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
|
||||
|
||||
def test_pyramid_attention_broadcast(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
num_layers = 4
|
||||
components = self.get_dummy_components(num_layers=num_layers)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 4
|
||||
frames = pipe(**inputs).frames # [B, F, C, H, W]
|
||||
original_image_slice = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800))
|
||||
assert pipe.pyramid_attention_broadcast_enabled
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
inputs["num_inference_steps"] = 4
|
||||
frames = pipe(**inputs).frames
|
||||
image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
pipe.disable_pyramid_attention_broadcast()
|
||||
assert not pipe.pyramid_attention_broadcast_enabled
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
frames = pipe(**inputs).frames
|
||||
image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:]
|
||||
|
||||
# We need to use higher tolerance because we are using a random model. With a converged/trained
|
||||
# model, the tolerance can be lower.
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_pab_enabled, atol=0.25
|
||||
), "PAB outputs should not differ much in specified timestep range."
|
||||
|
||||
assert np.allclose(
|
||||
image_slice_pab_enabled, image_slice_pab_disabled, atol=0.25
|
||||
), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range."
|
||||
assert np.allclose(
|
||||
original_image_slice, image_slice_pab_disabled, atol=0.25
|
||||
), "Original outputs should match when PAB is disabled."
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||