Compare commits

...

22 Commits

Author SHA1 Message Date
Aryan
1afc0fc616 update 2024-12-06 13:58:41 +01:00
Aryan
995b82fb67 update 2024-12-06 12:14:46 +01:00
Aryan
d95d61ae3f Merge branch 'main' into pyramid-attention-broadcast 2024-12-06 09:33:41 +01:00
Aryan
3de2c18964 Merge branch 'main' into pyramid-attention-broadcast 2024-11-09 22:08:29 +05:30
Aryan
c52cf422d0 Pyramid Attention Broadcast rewrite + introduce hooks (#9826)
* rewrite implementation with hooks

* make style

* update
2024-11-08 21:45:58 +05:30
Aryan
18b7d6d9e2 Merge branch 'main' into pyramid-attention-broadcast 2024-11-05 19:55:41 +05:30
Aryan
a5f51bbab3 Merge branch 'main' into pyramid-attention-broadcast 2024-11-01 01:16:13 +05:30
Aryan
37d23669cf Merge branch 'main' into pyramid-attention-broadcast 2024-10-30 11:09:11 +01:00
Aryan
6b1f55ec97 Update docs/source/en/api/pipelines/cogvideox.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2024-10-16 03:26:58 +05:30
Aryan
9cb4e876bc Update docs/source/en/api/pipelines/cogvideox.md
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2024-10-16 03:26:45 +05:30
Aryan
6816fe15a4 Merge branch 'main' into pyramid-attention-broadcast 2024-10-16 03:26:23 +05:30
Aryan
6265b65469 update 2024-10-05 22:22:01 +02:00
Aryan
afd0c176d1 add tests 2024-10-04 08:51:21 +02:00
Aryan
b3547c647e add docs 2024-10-04 08:16:24 +02:00
Aryan
9f6987fdb0 make style 2024-10-03 21:22:26 +02:00
Aryan
955e4f7c35 Merge branch 'main' into pyramid-attention-broadcast 2024-10-03 21:22:10 +02:00
Aryan
ae4abb14a3 update 2024-10-03 21:21:35 +02:00
Aryan
1c97e04fa5 Merge branch 'main' into pyramid-attention-broadcast 2024-10-03 12:36:33 +05:30
Aryan
d5c738defe make style 2024-10-03 09:06:19 +02:00
Aryan
373710167a update 2024-10-03 09:06:02 +02:00
Aryan
6d3bdb5511 add coauthor
Co-Authored-By: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com>
2024-10-03 08:34:57 +02:00
Aryan
67c729d448 start pyramid attention broadcast 2024-10-01 03:31:06 +02:00
16 changed files with 955 additions and 14 deletions

View File

@@ -15,7 +15,7 @@
# CogVideoX
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://arxiv.org/abs/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang.
[CogVideoX: Text-to-Video Diffusion Models with An Expert Transformer](https://huggingface.co/papers/2408.06072) from Tsinghua University & ZhipuAI, by Zhuoyi Yang, Jiayan Teng, Wendi Zheng, Ming Ding, Shiyu Huang, Jiazheng Xu, Yuanming Yang, Wenyi Hong, Xiaohan Zhang, Guanyu Feng, Da Yin, Xiaotao Gu, Yuxuan Zhang, Weihan Wang, Yean Cheng, Ting Liu, Bin Xu, Yuxiao Dong, Jie Tang.
The abstract from the paper is:
@@ -120,6 +120,45 @@ It is also worth noting that torchao quantization is fully compatible with [torc
- [torchao](https://gist.github.com/a-r-r-o-w/4d9732d17412888c885480c6521a9897)
- [quanto](https://gist.github.com/a-r-r-o-w/31be62828b00a9292821b85c1017effa)
### Pyramid Attention Broadcast
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and reusing cached attention states. This works because attention states change only slightly between successive steps: the difference is most pronounced in the spatial attention blocks, smaller in the temporal attention blocks, and smallest in the cross attention blocks. Cross attention computations can therefore be skipped most aggressively, followed by the temporal and then the spatial attention blocks. Combined with techniques such as sequence parallelism and classifier-free guidance parallelism, PAB achieves near real-time video generation.
Enable PAB with [`~PyramidAttentionBroadcastMixin.enable_pyramid_attention_broadcast`] on any pipeline that derives from the mixin and keeps track of the current inference timestep.
```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.float16)
pipe.to("cuda")
pipe.enable_pyramid_attention_broadcast(
    spatial_attn_skip_range=2,
    spatial_attn_timestep_range=[100, 850],
)
prompt = (
"A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
"The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
"pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
"casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
"The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
"atmosphere of this unique musical performance."
)
video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
export_to_video(video, "output.mp4", fps=8)
```
| model | model_memory | normal_memory | pab_memory | normal_time | pab_time | speedup |
|:----------:|:--------------:|:---------------:|:------------:|:-------------:|:----------:|:---------:|
| Cog-2b T2V | 12.55 | 35.342 | 35.342 | 86.915 | 63.914 | 1.359 |
| Cog-5b T2V | 19.66 | 40.945 | 40.945 | 246.152 | 168.59 | 1.460 |
| Cog-5b I2V | 19.764 | 42.74 | 42.74 | 246.867 | 170.111 | 1.451 |
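This PR also adds a functional entry point alongside the mixin: `apply_pyramid_attention_broadcast` together with a `PyramidAttentionBroadcastConfig` dataclass. Below is a minimal sketch of how it could be used, assuming the module path used by the tests in this branch (`diffusers.pipelines.pyramid_attention_broadcast_utils`); the exact import location and API may still change:

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.pipelines.pyramid_attention_broadcast_utils import (
    PyramidAttentionBroadcastConfig,
    apply_pyramid_attention_broadcast,
)

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.float16)
pipe.to("cuda")

# Re-use spatial attention states every 2nd inference step while the timestep is within (100, 850)
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 850),
)
apply_pyramid_attention_broadcast(pipe, config)
```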
## CogVideoXPipeline
[[autodoc]] CogVideoXPipeline

View File

@@ -16,7 +16,7 @@
![latte text-to-video](https://github.com/Vchitect/Latte/blob/52bc0029899babbd6e9250384c83d8ed2670ff7a/visuals/latte.gif?raw=true)
[Latte: Latent Diffusion Transformer for Video Generation](https://arxiv.org/abs/2401.03048) from Monash University, Shanghai AI Lab, Nanjing University, and Nanyang Technological University.
[Latte: Latent Diffusion Transformer for Video Generation](https://huggingface.co/papers/2401.03048) from Monash University, Shanghai AI Lab, Nanjing University, and Nanyang Technological University.
The abstract from the paper is:
@@ -70,6 +70,37 @@ Without torch.compile(): Average inference time: 16.246 seconds.
With torch.compile(): Average inference time: 14.573 seconds.
```
### Pyramid Attention Broadcast
[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
Pyramid Attention Broadcast (PAB) is a method that speeds up inference in diffusion models by systematically skipping attention computations between successive inference steps and re-using cached attention states. This works because the attention states do not differ much numerically between successive steps: the difference is most pronounced in the spatial attention blocks, smaller in the temporal attention blocks, and smallest in the cross attention blocks. Therefore, many cross attention computations can be skipped, followed by the temporal and spatial attention blocks. By combining other techniques like Sequence Parallelism and CFG Parallelism, the authors achieve near real-time video generation.
PAB can be enabled on any pipeline that derives from [`PyramidAttentionBroadcastMixin`] and keeps track of the current inference timestep. A minimal example demonstrating how to use PAB with Latte:
```python
import torch
from diffusers import LattePipeline
from diffusers.utils import export_to_gif
pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16)
pipe.enable_pyramid_attention_broadcast(
    spatial_attn_skip_range=2,
    cross_attn_skip_range=6,
    spatial_attn_timestep_range=[100, 800],
    cross_attn_timestep_range=[100, 800],
)
prompt = "A small cactus with a happy face in the Sahara desert."
videos = pipe(prompt).frames[0]
export_to_gif(videos, "latte.gif")
```
| model | model_memory | normal_memory | pab_memory | normal_time | pab_time | speedup |
|:----------:|:--------------:|:---------------:|:------------:|:-------------:|:----------:|:---------:|
| Latte | 11.007 | 25.594 | 25.594 | 28.026 | 24.073 | 1.164 |
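PAB can also be switched off again; based on the mixin added in this PR, this removes the attention hooks and restores the original forward passes:

```python
pipe.disable_pyramid_attention_broadcast()
assert not pipe.pyramid_attention_broadcast_enabled
```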
## LattePipeline
[[autodoc]] LattePipeline

View File

@@ -0,0 +1,278 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Any, Callable, Dict, Tuple
import torch
# Reference: https://github.com/huggingface/accelerate/blob/ba7ab93f5e688466ea56908ea3b056fae2f9a023/src/accelerate/hooks.py
class ModelHook:
r"""
A hook that contains callbacks to be executed just before and after the forward method of a model. The difference
from existing PyTorch hooks is that these also get passed the keyword arguments.
"""
def init_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when a model is initialized.
Args:
module (`torch.nn.Module`):
The module attached to this hook.
"""
return module
def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
r"""
Hook that is executed just before the forward method of the model.
Args:
module (`torch.nn.Module`):
The module whose forward pass will be executed just after this event.
args (`Tuple[Any]`):
The positional arguments passed to the module.
kwargs (`Dict[str, Any]`):
The keyword arguments passed to the module.
Returns:
`Tuple[Tuple[Any], Dict[str, Any]]`:
A tuple with the processed `args` and `kwargs`.
"""
return args, kwargs
def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
r"""
Hook that is executed just after the forward method of the model.
Args:
module (`torch.nn.Module`):
The module whose forward pass has been executed just before this event.
output (`Any`):
The output of the module.
Returns:
`Any`: The processed `output`.
"""
return output
def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
r"""
Hook that is executed when the hook is detached from a module.
Args:
module (`torch.nn.Module`):
The module detached from this hook.
"""
return module
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
return module
class SequentialHook(ModelHook):
r"""A hook that can contain several hooks and iterates through them at each event."""
def __init__(self, *hooks):
self.hooks = hooks
def init_hook(self, module):
for hook in self.hooks:
module = hook.init_hook(module)
return module
def pre_forward(self, module, *args, **kwargs):
for hook in self.hooks:
args, kwargs = hook.pre_forward(module, *args, **kwargs)
return args, kwargs
def post_forward(self, module, output):
for hook in self.hooks:
output = hook.post_forward(module, output)
return output
def detach_hook(self, module):
for hook in self.hooks:
module = hook.detach_hook(module)
return module
def reset_state(self, module):
for hook in self.hooks:
module = hook.reset_state(module)
return module
class PyramidAttentionBroadcastHook(ModelHook):
def __init__(
self,
skip_callback: Callable[[torch.nn.Module], bool],
# skip_range: int,
# timestep_range: Tuple[int, int],
# timestep_callback: Callable[[], Union[torch.LongTensor, int]],
) -> None:
super().__init__()
# self.skip_range = skip_range
# self.timestep_range = timestep_range
# self.timestep_callback = timestep_callback
self.skip_callback = skip_callback
self.cache = None
self._iteration = 0
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
# current_timestep = self.timestep_callback()
# is_within_timestep_range = self.timestep_range[0] < current_timestep < self.timestep_range[1]
# should_compute_attention = self._iteration % self.skip_range == 0
# if not is_within_timestep_range or should_compute_attention:
# output = module._old_forward(*args, **kwargs)
# else:
# output = self.attention_cache
if self.cache is not None and self.skip_callback(module):
output = self.cache
else:
output = module._old_forward(*args, **kwargs)
return module._diffusers_hook.post_forward(module, output)
def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
self.cache = output
return output
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module:
self.cache = None
self._iteration = 0
return module
class LayerSkipHook(ModelHook):
def __init__(self, skip_callback: Callable[[torch.nn.Module], bool]) -> None:
super().__init__()
self.skip_callback = skip_callback
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any:
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
if self.skip_callback(module):
# We want to skip this layer, so we have to return the input of the current layer
# as output of the next layer. But at this point, we don't have information about
# the arguments required by next layer. Even if we did, order matters unless we
# always pass kwargs. But that is not the case usually with hidden_states, encoder_hidden_states,
# temb, etc. TODO(aryan): implement correctly later
output = None
else:
output = module._old_forward(*args, **kwargs)
return module._diffusers_hook.post_forward(module, output)
def add_hook_to_module(module: torch.nn.Module, hook: ModelHook, append: bool = False):
r"""
Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook. To remove
this behavior and restore the original `forward` method, use `remove_hook_from_module`.
<Tip warning={true}>
If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks
together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class.
</Tip>
Args:
module (`torch.nn.Module`):
The module to attach a hook to.
hook (`ModelHook`):
The hook to attach.
append (`bool`, *optional*, defaults to `False`):
Whether the hook should be chained with an existing one (if module already contains a hook) or not.
Returns:
`torch.nn.Module`:
The same module, with the hook attached (the module is modified in place, so the result can be discarded).
"""
original_hook = hook
if append and getattr(module, "_diffusers_hook", None) is not None:
old_hook = module._diffusers_hook
remove_hook_from_module(module)
hook = SequentialHook(old_hook, hook)
if hasattr(module, "_diffusers_hook") and hasattr(module, "_old_forward"):
# If we already put some hook on this module, we replace it with the new one.
old_forward = module._old_forward
else:
old_forward = module.forward
module._old_forward = old_forward
module = hook.init_hook(module)
module._diffusers_hook = hook
if hasattr(original_hook, "new_forward"):
new_forward = original_hook.new_forward
else:
def new_forward(module, *args, **kwargs):
args, kwargs = module._diffusers_hook.pre_forward(module, *args, **kwargs)
output = module._old_forward(*args, **kwargs)
return module._diffusers_hook.post_forward(module, output)
# Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
# Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
if "GraphModuleImpl" in str(type(module)):
module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
else:
module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
return module
def remove_hook_from_module(module: torch.nn.Module, recurse: bool = False) -> torch.nn.Module:
"""
Removes any hook attached to a module via `add_hook_to_module`.
Args:
module (`torch.nn.Module`):
The module to remove the hook from.
recurse (`bool`, defaults to `False`):
Whether to remove the hooks recursively.
Returns:
`torch.nn.Module`:
The same module, with the hook detached (the module is modified in place, so the result can be discarded).
"""
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.detach_hook(module)
delattr(module, "_diffusers_hook")
if hasattr(module, "_old_forward"):
# Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
# Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
if "GraphModuleImpl" in str(type(module)):
module.__class__.forward = module._old_forward
else:
module.forward = module._old_forward
delattr(module, "_old_forward")
if recurse:
for child in module.children():
remove_hook_from_module(child, recurse)
return module
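For reference, here is a minimal sketch of how these primitives compose (the `LoggingHook` class and the toy module below are made up for illustration and are not part of the diff):

```python
import torch


class LoggingHook(ModelHook):
    """Toy hook that records the output shape of every forward call."""

    def __init__(self):
        self.shapes = []

    def post_forward(self, module, output):
        self.shapes.append(tuple(output.shape))
        return output


linear = torch.nn.Linear(4, 2)
hook = LoggingHook()
add_hook_to_module(linear, hook)  # rewrites linear.forward to run pre_forward/post_forward around it
linear(torch.randn(1, 4))
print(hook.shapes)  # [(1, 2)]
remove_hook_from_module(linear)  # restores the original forward
```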

View File

@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Dict, Optional, Union
import torch
from torch import nn
@@ -19,6 +20,7 @@ from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
from ..attention import BasicTransformerBlock
from ..attention_processor import AttentionProcessor
from ..embeddings import PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
@@ -165,6 +167,66 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def forward(
self,
hidden_states: torch.Tensor,

View File

@@ -779,6 +779,7 @@ class AllegroPipeline(DiffusionPipeline):
negative_prompt_attention_mask,
)
self._guidance_scale = guidance_scale
self._current_timestep = None
self._interrupt = False
# 2. Default height and width to transformer
@@ -856,6 +857,7 @@ class AllegroPipeline(DiffusionPipeline):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -894,6 +896,8 @@ class AllegroPipeline(DiffusionPipeline):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
self._current_timestep = None
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
video = self.decode_latents(latents)

View File

@@ -622,6 +622,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Default call parameters
@@ -700,6 +701,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -755,6 +757,8 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
self._current_timestep = None
if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:]

View File

@@ -675,6 +675,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Default call parameters
@@ -761,6 +762,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -810,6 +812,8 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
self._current_timestep = None
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)

View File

@@ -722,6 +722,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
negative_prompt_embeds=negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._current_timestep = None
self._attention_kwargs = attention_kwargs
self._interrupt = False
@@ -809,6 +810,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -868,6 +870,8 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
self._current_timestep = None
if not output_type == "latent":
# Discard any padding frames that were added for CogVideoX 1.5
latents = latents[:, additional_frames:]

View File

@@ -701,6 +701,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Default call parameters
@@ -787,6 +788,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -842,6 +844,8 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
self._current_timestep = None
if not output_type == "latent":
video = self.decode_latents(latents)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)

View File

@@ -623,7 +623,7 @@ class LattePipeline(DiffusionPipeline):
clean_caption: bool = True,
mask_feature: bool = True,
enable_temporal_attentions: bool = True,
decode_chunk_size: Optional[int] = None,
decode_chunk_size: int = 14,
) -> Union[LattePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
@@ -719,6 +719,7 @@ class LattePipeline(DiffusionPipeline):
negative_prompt_embeds,
)
self._guidance_scale = guidance_scale
self._current_timestep = None
self._interrupt = False
# 2. Default height and width to transformer
@@ -780,6 +781,7 @@ class LattePipeline(DiffusionPipeline):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -836,8 +838,10 @@ class LattePipeline(DiffusionPipeline):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if not output_type == "latents":
video = self.decode_latents(latents, video_length, decode_chunk_size=14)
self._current_timestep = None
if not output_type == "latent":
video = self.decode_latents(latents, video_length, decode_chunk_size=decode_chunk_size)
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
else:
video = latents

View File

@@ -1088,6 +1088,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
functions correctly when applying enable_model_cpu_offload.
"""
if hasattr(self, "_diffusers_hook"):
self._diffusers_hook.reset_state()
if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
# `enable_model_cpu_offload` has not been called, so silently do nothing
return

View File

@@ -0,0 +1,315 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple
import torch.nn as nn
from ..models.attention_processor import Attention
from ..models.hooks import PyramidAttentionBroadcastHook, add_hook_to_module, remove_hook_from_module
from ..utils import logging
from .pipeline_utils import DiffusionPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_ATTENTION_CLASSES = (Attention,)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
# keep this a tuple: identifier matching iterates over it, and iterating a bare string would yield single characters
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
@dataclass
class PyramidAttentionBroadcastConfig:
spatial_attention_block_skip_range: Optional[int] = None
temporal_attention_block_skip_range: Optional[int] = None
cross_attention_block_skip_range: Optional[int] = None
spatial_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
class PyramidAttentionBroadcastState:
def __init__(self) -> None:
self.iteration = 0
def reset_state(self):
self.iteration = 0
def apply_pyramid_attention_broadcast(
pipeline: DiffusionPipeline,
config: Optional[PyramidAttentionBroadcastConfig] = None,
denoiser: Optional[nn.Module] = None,
):
if config is None:
config = PyramidAttentionBroadcastConfig()
if (
config.spatial_attention_block_skip_range is None
and config.temporal_attention_block_skip_range is None
and config.cross_attention_block_skip_range is None
):
logger.warning(
"Pyramid Attention Broadcast requires one or more of `spatial_attention_block_skip_range`, `temporal_attention_block_skip_range` "
"or `cross_attention_block_skip_range` parameters to be set to an integer, not `None`. Defaulting to using `spatial_attention_block_skip_range=2`. "
"To avoid this warning, please set one of the above parameters."
)
config.spatial_attention_block_skip_range = 2
if denoiser is None:
denoiser = pipeline.transformer if hasattr(pipeline, "transformer") else pipeline.unet
for name, module in denoiser.named_modules():
if not isinstance(module, _ATTENTION_CLASSES):
continue
if isinstance(module, Attention):
_apply_pyramid_attention_broadcast_on_attention_class(pipeline, name, module, config)
def apply_pyramid_attention_broadcast_on_module(
module: Attention,
block_skip_range: int,
timestep_skip_range: Tuple[int, int],
current_timestep_callback: Callable[[], int],
):
module._pyramid_attention_broadcast_state = PyramidAttentionBroadcastState()
min_timestep, max_timestep = timestep_skip_range
def skip_callback(attention_module: nn.Module) -> bool:
pab_state: PyramidAttentionBroadcastState = attention_module._pyramid_attention_broadcast_state
current_timestep = current_timestep_callback()
is_within_timestep_range = min_timestep < current_timestep < max_timestep
if is_within_timestep_range:
# As soon as the current timestep is within the timestep range, we start skipping attention computation.
# The following inference steps will compute the attention every `block_skip_range` steps.
should_compute_attention = pab_state.iteration > 0 and pab_state.iteration % block_skip_range == 0
pab_state.iteration += 1
return not should_compute_attention
# The current timestep is not yet within the range where attention can be skipped with minimal quality
# loss, as described in the paper. So, the attention computation cannot be skipped.
return False
hook = PyramidAttentionBroadcastHook(skip_callback=skip_callback)
add_hook_to_module(module, hook, append=True)
def _apply_pyramid_attention_broadcast_on_attention_class(
pipeline: DiffusionPipeline, name: str, module: Attention, config: PyramidAttentionBroadcastConfig
):
# Similar check as PEFT to determine if a string layer name matches a module name
is_spatial_self_attention = (
any(
f"{identifier}." in name or identifier == name for identifier in config.spatial_attention_block_identifiers
)
and config.spatial_attention_block_skip_range is not None
and not module.is_cross_attention
)
is_temporal_self_attention = (
any(
f"{identifier}." in name or identifier == name
for identifier in config.temporal_attention_block_identifiers
)
and config.temporal_attention_block_skip_range is not None
and not module.is_cross_attention
)
is_cross_attention = (
any(f"{identifier}." in name or identifier == name for identifier in config.cross_attention_block_identifiers)
and config.cross_attention_block_skip_range is not None
and module.is_cross_attention
)
block_skip_range, timestep_skip_range = None, None
if is_spatial_self_attention:
block_skip_range = config.spatial_attention_block_skip_range
timestep_skip_range = config.spatial_attention_timestep_skip_range
elif is_temporal_self_attention:
block_skip_range = config.temporal_attention_block_skip_range
timestep_skip_range = config.temporal_attention_timestep_skip_range
elif is_cross_attention:
block_skip_range = config.cross_attention_block_skip_range
timestep_skip_range = config.cross_attention_timestep_skip_range
if block_skip_range is None or timestep_skip_range is None:
logger.warning(f"Unable to apply Pyramid Attention Broadcast to the selected layer: {name}.")
return
def current_timestep_callback():
return pipeline._current_timestep
apply_pyramid_attention_broadcast_on_module(
module, block_skip_range, timestep_skip_range, current_timestep_callback
)
class PyramidAttentionBroadcastMixin:
r"""Mixin class for [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588)."""
def _enable_pyramid_attention_broadcast(self) -> None:
denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet
for name, module in denoiser.named_modules():
if isinstance(module, Attention):
is_spatial_attention = (
any(x in name for x in self._pab_spatial_attn_layer_identifiers)
and self._pab_spatial_attn_skip_range is not None
and not module.is_cross_attention
)
is_temporal_attention = (
any(x in name for x in self._pab_temporal_attn_layer_identifiers)
and self._pab_temporal_attn_skip_range is not None
and not module.is_cross_attention
)
is_cross_attention = (
any(x in name for x in self._pab_cross_attn_layer_identifiers)
and self._pab_cross_attn_skip_range is not None
and module.is_cross_attention
)
skip_range, timestep_range = None, None
if is_spatial_attention:
skip_range = self._pab_spatial_attn_skip_range
timestep_range = self._pab_spatial_attn_timestep_range
if is_temporal_attention:
skip_range = self._pab_temporal_attn_skip_range
timestep_range = self._pab_temporal_attn_timestep_range
if is_cross_attention:
skip_range = self._pab_cross_attn_skip_range
timestep_range = self._pab_cross_attn_timestep_range
if skip_range is None:
continue
# logger.debug(f"Enabling Pyramid Attention Broadcast in layer: {name}")
print(f"Enabling Pyramid Attention Broadcast in layer: {name}")
# Re-use the hook-based helper defined above so the skip schedule and caching stay consistent
apply_pyramid_attention_broadcast_on_module(
module, skip_range, timestep_range, self._pyramid_attention_broadcast_timestep_callback
)
def _disable_pyramid_attention_broadcast(self) -> None:
denoiser: nn.Module = self.transformer if hasattr(self, "transformer") else self.unet
for name, module in denoiser.named_modules():
logger.debug(f"Disabling Pyramid Attention Broadcast in layer: {name}")
remove_hook_from_module(module)
def _pyramid_attention_broadcast_timestep_callback(self):
return self._current_timestep
def enable_pyramid_attention_broadcast(
self,
spatial_attn_skip_range: Optional[int] = None,
spatial_attn_timestep_range: Tuple[int, int] = (100, 800),
temporal_attn_skip_range: Optional[int] = None,
cross_attn_skip_range: Optional[int] = None,
temporal_attn_timestep_range: Tuple[int, int] = (100, 800),
cross_attn_timestep_range: Tuple[int, int] = (100, 800),
spatial_attn_layer_identifiers: List[str] = ["blocks", "transformer_blocks"],
temporal_attn_layer_identifiers: List[str] = ["temporal_transformer_blocks"],
cross_attn_layer_identifiers: List[str] = ["blocks", "transformer_blocks"],
) -> None:
r"""
Enable Pyramid Attention Broadcast to speed up inference by re-using attention states and systematically skipping
attention computation, as described in the paper: [Pyramid Attention Broadcast](https://www.arxiv.org/abs/2408.12588).
Args:
spatial_attn_skip_range (`int`, *optional*):
Controls how often spatial attention states are recomputed. If set to `N`, attention is computed on
every N-th inference step within the active timestep range and the cached states are re-used for the
`N - 1` steps in between. Different models tolerate different amounts of re-use depending on how much
the attention states change between successive steps, so this parameter should be tuned per model.
Setting this value to `2` is recommended for the models PAB has been experimented with.
temporal_attn_skip_range (`int`, *optional*):
Controls how often temporal attention states are recomputed. If set to `N`, attention is computed on
every N-th inference step within the active timestep range and the cached states are re-used for the
`N - 1` steps in between. Different models tolerate different amounts of re-use depending on how much
the attention states change between successive steps, so this parameter should be tuned per model.
Setting this value to `4` is recommended for the models PAB has been experimented with.
cross_attn_skip_range (`int`, *optional*):
Controls how often cross attention states are recomputed. If set to `N`, attention is computed on
every N-th inference step within the active timestep range and the cached states are re-used for the
`N - 1` steps in between. Different models tolerate different amounts of re-use depending on how much
the attention states change between successive steps, so this parameter should be tuned per model.
Setting this value to `6` is recommended for the models PAB has been experimented with.
spatial_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The timestep range between which PAB will remain activated in spatial attention blocks. While
activated, PAB will re-use attention computations between inference steps.
temporal_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The timestep range between which PAB will remain activated in temporal attention blocks. While
activated, PAB will re-use attention computations between inference steps.
cross_attn_timestep_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The timestep range between which PAB will remain activated in cross attention blocks. While activated,
PAB will re-use attention computations between inference steps.
"""
if spatial_attn_timestep_range[0] > spatial_attn_timestep_range[1]:
raise ValueError(
"Expected `spatial_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied."
)
if temporal_attn_timestep_range[0] > temporal_attn_timestep_range[1]:
raise ValueError(
"Expected `temporal_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied."
)
if cross_attn_timestep_range[0] > cross_attn_timestep_range[1]:
raise ValueError(
"Expected `cross_attn_timestep_range` to be a tuple of two integers, with first value lesser or equal than second. These correspond to the min and max timestep between which PAB will be applied."
)
self._pab_spatial_attn_skip_range = spatial_attn_skip_range
self._pab_temporal_attn_skip_range = temporal_attn_skip_range
self._pab_cross_attn_skip_range = cross_attn_skip_range
self._pab_spatial_attn_timestep_range = spatial_attn_timestep_range
self._pab_temporal_attn_timestep_range = temporal_attn_timestep_range
self._pab_cross_attn_timestep_range = cross_attn_timestep_range
self._pab_spatial_attn_layer_identifiers = spatial_attn_layer_identifiers
self._pab_temporal_attn_layer_identifiers = temporal_attn_layer_identifiers
self._pab_cross_attn_layer_identifiers = cross_attn_layer_identifiers
self._pab_enabled = any(x is not None for x in (spatial_attn_skip_range, temporal_attn_skip_range, cross_attn_skip_range))
self._enable_pyramid_attention_broadcast()
def disable_pyramid_attention_broadcast(self) -> None:
r"""Disables the pyramid attention broadcast sampling mechanism."""
self._pab_spatial_attn_skip_range = None
self._pab_temporal_attn_skip_range = None
self._pab_cross_attn_skip_range = None
self._pab_spatial_attn_timestep_range = None
self._pab_temporal_attn_timestep_range = None
self._pab_cross_attn_timestep_range = None
self._pab_spatial_attn_layer_identifiers = None
self._pab_temporal_attn_layer_identifiers = None
self._pab_cross_attn_layer_identifiers = None
self._pab_enabled = False
@property
def pyramid_attention_broadcast_enabled(self):
return hasattr(self, "_pab_enabled") and self._pab_enabled

View File

@@ -21,6 +21,7 @@ import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler
from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
@@ -59,7 +60,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)
test_xformers_attention = False
def get_dummy_components(self):
def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = CogVideoXTransformer3DModel(
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings
@@ -71,7 +72,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
out_channels=4,
time_embed_dim=2,
text_embed_dim=32, # Must match with tiny-random-t5
num_layers=1,
num_layers=num_layers,
sample_width=2, # latent width: 2 -> final width: 16
sample_height=2, # latent height: 2 -> final height: 16
sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9
@@ -319,6 +320,54 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
def test_pyramid_attention_broadcast(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
num_layers = 4
components = self.get_dummy_components(num_layers=num_layers)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
frames = pipe(**inputs).frames # [B, F, C, H, W]
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800))
assert pipe.pyramid_attention_broadcast_enabled
num_pab_processors = sum(
[
isinstance(processor, PyramidAttentionBroadcastAttentionProcessorWrapper)
for processor in pipe.transformer.attn_processors.values()
]
)
assert num_pab_processors == num_layers
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
frames = pipe(**inputs).frames
image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:]
pipe.disable_pyramid_attention_broadcast()
assert not pipe.pyramid_attention_broadcast_enabled
inputs = self.get_dummy_inputs(device)
frames = pipe(**inputs).frames
image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:]
# We need to use higher tolerance because we are using a random model. With a converged/trained
# model, the tolerance can be lower.
assert np.allclose(
original_image_slice, image_slice_pab_enabled, atol=0.2
), "PAB outputs should not differ much in specified timestep range."
assert np.allclose(
image_slice_pab_enabled, image_slice_pab_disabled, atol=0.2
), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range."
assert np.allclose(
original_image_slice, image_slice_pab_disabled, atol=0.2
), "Original outputs should match when PAB is disabled."
@slow
@require_torch_gpu

View File

@@ -22,6 +22,7 @@ from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler
from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
@@ -61,7 +62,7 @@ class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
)
test_xformers_attention = False
def get_dummy_components(self):
def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = CogVideoXTransformer3DModel(
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings
@@ -76,7 +77,7 @@ class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
out_channels=4,
time_embed_dim=2,
text_embed_dim=32, # Must match with tiny-random-t5
num_layers=1,
num_layers=num_layers,
sample_width=2, # latent width: 2 -> final width: 16
sample_height=2, # latent height: 2 -> final height: 16
sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9
@@ -342,6 +343,54 @@ class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
def test_pyramid_attention_broadcast(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
num_layers = 4
components = self.get_dummy_components(num_layers=num_layers)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
frames = pipe(**inputs).frames # [B, F, C, H, W]
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800))
assert pipe.pyramid_attention_broadcast_enabled
num_pab_processors = sum(
[
isinstance(processor, PyramidAttentionBroadcastAttentionProcessorWrapper)
for processor in pipe.transformer.attn_processors.values()
]
)
assert num_pab_processors == num_layers
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
frames = pipe(**inputs).frames
image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:]
pipe.disable_pyramid_attention_broadcast()
assert not pipe.pyramid_attention_broadcast_enabled
inputs = self.get_dummy_inputs(device)
frames = pipe(**inputs).frames
image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:]
# We need to use higher tolerance because we are using a random model. With a converged/trained
# model, the tolerance can be lower.
assert np.allclose(
original_image_slice, image_slice_pab_enabled, atol=0.2
), "PAB outputs should not differ much in specified timestep range."
assert np.allclose(
image_slice_pab_enabled, image_slice_pab_disabled, atol=0.2
), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range."
assert np.allclose(
original_image_slice, image_slice_pab_disabled, atol=0.2
), "Original outputs should match when PAB is disabled."
@slow
@require_torch_gpu

View File

@@ -21,6 +21,7 @@ from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel, CogVideoXVideoToVideoPipeline, DDIMScheduler
from diffusers.pipelines.pyramid_attention_broadcast_utils import PyramidAttentionBroadcastAttentionProcessorWrapper
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -53,7 +54,7 @@ class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
)
test_xformers_attention = False
def get_dummy_components(self):
def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = CogVideoXTransformer3DModel(
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings
@@ -65,7 +66,7 @@ class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
out_channels=4,
time_embed_dim=2,
text_embed_dim=32, # Must match with tiny-random-t5
num_layers=1,
num_layers=num_layers,
sample_width=2, # latent width: 2 -> final width: 16
sample_height=2, # latent height: 2 -> final height: 16
sample_frames=9, # latent frames: (9 - 1) / 4 + 1 = 3 -> final frames: 9
@@ -323,3 +324,51 @@ class CogVideoXVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC
assert np.allclose(
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
), "Original outputs should match when fused QKV projections are disabled."
def test_pyramid_attention_broadcast(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
num_layers = 4
components = self.get_dummy_components(num_layers=num_layers)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
frames = pipe(**inputs).frames # [B, F, C, H, W]
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800))
assert pipe.pyramid_attention_broadcast_enabled
num_pab_processors = sum(
[
isinstance(processor, PyramidAttentionBroadcastAttentionProcessorWrapper)
for processor in pipe.transformer.attn_processors.values()
]
)
assert num_pab_processors == num_layers
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
frames = pipe(**inputs).frames
image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:]
pipe.disable_pyramid_attention_broadcast()
assert not pipe.pyramid_attention_broadcast_enabled
inputs = self.get_dummy_inputs(device)
frames = pipe(**inputs).frames
image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:]
# We need to use higher tolerance because we are using a random model. With a converged/trained
# model, the tolerance can be lower.
assert np.allclose(
original_image_slice, image_slice_pab_enabled, atol=0.2
), "PAB outputs should not differ much in specified timestep range."
assert np.allclose(
image_slice_pab_enabled, image_slice_pab_disabled, atol=0.2
), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range."
assert np.allclose(
original_image_slice, image_slice_pab_disabled, atol=0.2
), "Original outputs should match when PAB is disabled."

View File

@@ -53,11 +53,11 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
required_optional_params = PipelineTesterMixin.required_optional_params
def get_dummy_components(self):
def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = LatteTransformer3DModel(
sample_size=8,
num_layers=1,
num_layers=num_layers,
patch_size=2,
attention_head_dim=8,
num_attention_heads=3,
@@ -264,6 +264,47 @@ class LattePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def test_xformers_attention_forwardGenerator_pass(self):
super()._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
def test_pyramid_attention_broadcast(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
num_layers = 4
components = self.get_dummy_components(num_layers=num_layers)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
frames = pipe(**inputs).frames # [B, F, C, H, W]
original_image_slice = frames[0, -2:, -1, -3:, -3:]
pipe.enable_pyramid_attention_broadcast(spatial_attn_skip_range=2, spatial_attn_timestep_range=(100, 800))
assert pipe.pyramid_attention_broadcast_enabled
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
frames = pipe(**inputs).frames
image_slice_pab_enabled = frames[0, -2:, -1, -3:, -3:]
pipe.disable_pyramid_attention_broadcast()
assert not pipe.pyramid_attention_broadcast_enabled
inputs = self.get_dummy_inputs(device)
frames = pipe(**inputs).frames
image_slice_pab_disabled = frames[0, -2:, -1, -3:, -3:]
# We need to use higher tolerance because we are using a random model. With a converged/trained
# model, the tolerance can be lower.
assert np.allclose(
original_image_slice, image_slice_pab_enabled, atol=0.25
), "PAB outputs should not differ much in specified timestep range."
assert np.allclose(
image_slice_pab_enabled, image_slice_pab_disabled, atol=0.25
), "Outputs, with PAB enabled, shouldn't differ much when PAB is disabled in specified timestep range."
assert np.allclose(
original_image_slice, image_slice_pab_disabled, atol=0.25
), "Original outputs should match when PAB is disabled."
@slow
@require_torch_gpu