Compare commits


7 Commits

Author        SHA1        Message                                        Date
Sayak Paul    8470ce3d06  Merge branch 'main' into cache-docs-fixes      2026-01-10 09:13:39 +05:30
sayakpaul     73601980c2  up                                             2026-01-10 09:09:44 +05:30
Sayak Paul    25795856e0  Update docs/source/en/optimization/cache.md    2026-01-10 09:07:46 +05:30
                          Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Sayak Paul    d76b744ac3  Merge branch 'main' into cache-docs-fixes      2025-11-26 15:22:39 +05:30
Sayak Paul    b26867b628  Merge branch 'main' into cache-docs-fixes      2025-11-20 10:06:19 +05:30
Sayak Paul    e3f441648c  Update docs/source/en/optimization/cache.md    2025-11-20 10:00:46 +05:30
                          Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
sayakpaul     c6cfc5ce1d  polish caching docs.                           2025-11-19 08:40:28 +05:30
5 changed files with 30 additions and 33 deletions

View File

@@ -29,7 +29,7 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
[[autodoc]] apply_faster_cache
-### FirstBlockCacheConfig
+## FirstBlockCacheConfig
[[autodoc]] FirstBlockCacheConfig

View File

@@ -68,6 +68,20 @@ config = FasterCacheConfig(
pipeline.transformer.enable_cache(config)
```
+## FirstBlockCache
+
+[FirstBlockCache](https://huggingface.co/docs/diffusers/main/en/api/cache#diffusers.FirstBlockCacheConfig) checks how much the early layers of the denoiser change from one timestep to the next. If the change is small, the model skips the expensive later layers and reuses the previous output.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
+)
+apply_first_block_cache(pipeline.transformer, FirstBlockCacheConfig(threshold=0.2))
+```
## TaylorSeer Cache
[TaylorSeer Cache](https://huggingface.co/papers/2503.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. The method predicts future outputs based on past computations, reusing them at specified intervals to reduce redundant calculations.
@@ -87,8 +101,7 @@ from diffusers import FluxPipeline, TaylorSeerCacheConfig
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
-)
-pipe.to("cuda")
+).to("cuda")
config = TaylorSeerCacheConfig(
cache_interval=5,
@@ -97,4 +110,4 @@ config = TaylorSeerCacheConfig(
taylor_factors_dtype=torch.bfloat16,
)
pipe.transformer.enable_cache(config)
-```
+```
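
Note that the new FirstBlockCache docs use `apply_first_block_cache`, the hook-level entry point. Since the `CacheMixin` change later in this diff adds `FirstBlockCacheConfig` to the accepted config types, the same cache should also be reachable through the model-level `enable_cache` API. A minimal sketch, reusing the checkpoint and threshold from the example above; the equivalence is an assumption based on the docstring change below:

```py
import torch
from diffusers import DiffusionPipeline
from diffusers.hooks import FirstBlockCacheConfig

pipeline = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image", torch_dtype=torch.bfloat16
)
# enable_cache dispatches on the config type (see the CacheMixin diff below),
# so this should be the model-level equivalent of apply_first_block_cache.
pipeline.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))
```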

View File

@@ -41,9 +41,11 @@ class CacheMixin:
Enable caching techniques on the model.
Args:
-config (`Union[PyramidAttentionBroadcastConfig]`):
+config (`Union[PyramidAttentionBroadcastConfig, FasterCacheConfig, FirstBlockCacheConfig]`):
The configuration for applying the caching technique. Currently supported caching techniques are:
- [`~hooks.PyramidAttentionBroadcastConfig`]
- [`~hooks.FasterCacheConfig`]
+- [`~hooks.FirstBlockCacheConfig`]
Example:
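
The hunk stops at `Example:`, so the docstring's example body is not visible here. For orientation, a sketch of the usage pattern `enable_cache` documents, with a `PyramidAttentionBroadcastConfig` whose model choice and parameter values are illustrative assumptions:

```py
import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Skip every other spatial attention block within the given timestep window.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)
```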

View File

@@ -68,10 +68,6 @@ class MellonParam:
def image_latents(cls, display: str = "input") -> "MellonParam":
return cls(name="image_latents", label="Image Latents", type="latents", display=display)
-@classmethod
-def first_frame_latents(cls, display: str = "input") -> "MellonParam":
-return cls(name="first_frame_latents", label="First Frame Latents", type="latents", display=display)
@classmethod
def image_latents_with_strength(cls) -> "MellonParam":
return cls(
@@ -93,10 +89,6 @@ class MellonParam:
def embeddings(cls, display: str = "output") -> "MellonParam":
return cls(name="embeddings", label="Text Embeddings", type="embeddings", display=display)
-@classmethod
-def image_embeds(cls, display: str = "output") -> "MellonParam":
-return cls(name="image_embeds", label="Image Embeddings", type="image_embeds", display=display)
@classmethod
def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam":
return cls(
@@ -194,16 +186,6 @@ class MellonParam:
"""
return cls(name="vae", label="VAE", type="diffusers_auto_model", display="input")
-@classmethod
-def image_encoder(cls) -> "MellonParam":
-"""
-Image Encoder model info dict.
-Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
-the actual model.
-"""
-return cls(name="image_encoder", label="Image Encoder", type="diffusers_auto_model", display="input")
@classmethod
def unet(cls) -> "MellonParam":
"""

View File

@@ -84,7 +84,7 @@ class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
block_names = ["image_resize", "vae_encoder"]
block_names = ["image_resize", "vae_image_encoder"]
@property
def description(self):
@@ -142,7 +142,7 @@ class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
block_names = ["image_resize", "last_image_resize", "vae_encoder"]
block_names = ["image_resize", "last_image_resize", "vae_image_encoder"]
@property
def description(self):
@@ -203,7 +203,7 @@ class WanAutoImageEncoderStep(AutoPipelineBlocks):
## vae encoder
class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
block_names = ["flf2v_vae_image_encoder", "image2video_vae_image_encoder"]
block_trigger_inputs = ["last_image", "image"]
@property
@@ -251,7 +251,7 @@ class WanAutoBlocks(SequentialPipelineBlocks):
block_names = [
"text_encoder",
"image_encoder",
"vae_encoder",
"vae_image_encoder",
"denoise",
"decode",
]
@@ -353,7 +353,7 @@ class Wan22AutoBlocks(SequentialPipelineBlocks):
]
block_names = [
"text_encoder",
"vae_encoder",
"vae_image_encoder",
"denoise",
"decode",
]
@@ -384,7 +384,7 @@ IMAGE2VIDEO_BLOCKS = InsertableDict(
[
("image_resize", WanImageResizeStep),
("image_encoder", WanImage2VideoImageEncoderStep),
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
("input", WanTextInputStep),
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
("set_timesteps", WanSetTimestepsStep),
@@ -401,7 +401,7 @@ FLF2V_BLOCKS = InsertableDict(
("image_resize", WanImageResizeStep),
("last_image_resize", WanImageCropResizeStep),
("image_encoder", WanFLF2VImageEncoderStep),
("vae_encoder", WanFLF2VVaeImageEncoderStep),
("vae_image_encoder", WanFLF2VVaeImageEncoderStep),
("input", WanTextInputStep),
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
("set_timesteps", WanSetTimestepsStep),
@@ -416,7 +416,7 @@ AUTO_BLOCKS = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("image_encoder", WanAutoImageEncoderStep),
("vae_encoder", WanAutoVaeImageEncoderStep),
("vae_image_encoder", WanAutoVaeImageEncoderStep),
("denoise", WanAutoDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
@@ -438,7 +438,7 @@ TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict(
IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
[
("image_resize", WanImageResizeStep),
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
("input", WanTextInputStep),
("set_timesteps", WanSetTimestepsStep),
("prepare_latents", WanPrepareLatentsStep),
@@ -450,7 +450,7 @@ IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
AUTO_BLOCKS_WAN22 = InsertableDict(
[
("text_encoder", WanTextEncoderStep),
("vae_encoder", WanAutoVaeImageEncoderStep),
("vae_image_encoder", WanAutoVaeImageEncoderStep),
("denoise", Wan22AutoDenoiseStep),
("decode", WanImageVaeDecoderStep),
]
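
The hunks in this file mechanically rename the `vae_encoder` block key to `vae_image_encoder`. Since `block_names` and the `InsertableDict` tables map plain string keys to block classes, the rename only holds together if every lookup site is updated as well. A minimal sketch of the failure mode for stale callers, assuming the import path of the file edited above:

```py
# Assumed import path for the Wan modular blocks table edited in this diff.
from diffusers.modular_pipelines.wan.modular_blocks import AUTO_BLOCKS

# The VAE image-encoder step is now registered under the new key...
assert "vae_image_encoder" in AUTO_BLOCKS

# ...so any code still indexing the old key fails loudly.
try:
    AUTO_BLOCKS["vae_encoder"]
except KeyError:
    print("update stale 'vae_encoder' lookups to 'vae_image_encoder'")
```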