Compare commits

...

7 Commits

Author SHA1 Message Date
sayakpaul
e229ed1538 stricter check. 2025-06-26 06:59:42 +05:30
Sayak Paul
e2c7d17510 Merge branch 'main' into compile_utils 2025-06-26 06:55:57 +05:30
Sayak Paul
45ff4a8827 Update docs/source/en/optimization/fp16.md 2025-06-24 14:05:26 +05:30
Sayak Paul
9196f3d1ba Merge branch 'main' into compile_utils 2025-06-24 13:50:42 +05:30
github-actions[bot]
13eca6ef2d Apply style fixes 2025-06-24 08:20:04 +00:00
Animesh Jain
932914f45d Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-06-24 01:14:43 -07:00
Animesh Jain
f794d66f1e [rfc][compile] compile method for DiffusionPipeline 2025-06-24 00:47:30 -07:00
9 changed files with 101 additions and 3 deletions

View File

@@ -152,9 +152,39 @@ Compilation is slow the first time, but once compiled, it is significantly faste
### Regional compilation
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) reduces the cold start compilation time by only compiling a specific repeated region (or block) of the model instead of the entire model. The compiler reuses the cached and compiled code for the other blocks.
[Accelerate](https://huggingface.co/docs/accelerate/index) provides the [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method for automatically compiling the repeated blocks of a `nn.Module` sequentially. The rest of the model is compiled separately.
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently-repeated block(s)** of a model, typically a Transformer layer, enabling reuse of compiled artifacts for every subsequent occurrence.
For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **810 ×**.
To make this effortless, [`ModelMixin`] exposes [`ModelMixin.compile_repeated_blocks`] API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:
```py
# pip install -U diffusers
import torch
from diffusers import StableDiffusionXLPipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
).to("cuda")
# Compile only the repeated Transformer layers inside the UNet
pipe.unet.compile_repeated_blocks(fullgraph=True)
```
To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
```py
class MyUNet(ModelMixin):
_repeated_blocks = ("Transformer2DModel",) # ← compiled by default
```
For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
**Relation to Accelerate compile_regions** There is also a separate API in [accelerate](https://huggingface.co/docs/accelerate/index) - [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.
```py
# pip install -U accelerate
@@ -167,6 +197,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
).to("cuda")
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.
### Graph breaks
@@ -241,4 +273,4 @@ An input is projected into three subspaces, represented by the projection matric
```py
pipeline.fuse_qkv_projections()
```
```

View File

@@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_keep_in_fp32_modules = None
_skip_layerwise_casting_patterns = None
_supports_group_offloading = True
_repeated_blocks = []
def __init__(self):
super().__init__()
@@ -1404,6 +1405,39 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
return super().float(*args)
def compile_repeated_blocks(self, *args, **kwargs):
"""
Compiles *only* the frequently repeated sub-modules of a model (e.g. the Transformer layers) instead of
compiling the entire model. This technique—often called **regional compilation** (see the PyTorch recipe
https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) can reduce end-to-end compile time
substantially, while preserving the runtime speed-ups you would expect from a full `torch.compile`.
The set of sub-modules to compile is discovered by the presence of **`_repeated_blocks`** attribute in the
model definition. Define this attribute on your model subclass as a list/tuple of class names (strings). Every
module whose class name matches will be compiled.
Once discovered, each matching sub-module is compiled by calling `submodule.compile(*args, **kwargs)`. Any
positional or keyword arguments you supply to `compile_repeated_blocks` are forwarded verbatim to
`torch.compile`.
"""
repeated_blocks = getattr(self, "_repeated_blocks", None)
if not repeated_blocks:
raise ValueError(
"`_repeated_blocks` attribute is empty. "
f"Set `_repeated_blocks` for the class `{self.__class__.__name__}` to benefit from faster compilation. "
)
has_compiled_region = False
for submod in self.modules():
if submod.__class__.__name__ in repeated_blocks:
submod.compile(*args, **kwargs)
has_compiled_region = True
if not has_compiled_region:
raise ValueError(
f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
)
@classmethod
def _load_pretrained_model(
cls,

View File

@@ -407,6 +407,7 @@ class ChromaTransformer2DModel(
_supports_gradient_checkpointing = True
_no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
_repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
@register_to_config

View File

@@ -227,6 +227,7 @@ class FluxTransformer2DModel(
_supports_gradient_checkpointing = True
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
@register_to_config
def __init__(

View File

@@ -870,6 +870,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
"HunyuanVideoPatchEmbed",
"HunyuanVideoTokenRefiner",
]
_repeated_blocks = [
"HunyuanVideoTransformerBlock",
"HunyuanVideoSingleTransformerBlock",
"HunyuanVideoPatchEmbed",
"HunyuanVideoTokenRefiner",
]
@register_to_config
def __init__(

View File

@@ -328,6 +328,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["LTXVideoTransformerBlock"]
@register_to_config
def __init__(

View File

@@ -345,6 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
_no_split_modules = ["WanTransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["WanTransformerBlock"]
@register_to_config
def __init__(

View File

@@ -167,6 +167,7 @@ class UNet2DConditionModel(
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
_skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["BasicTransformerBlock"]
@register_to_config
def __init__(

View File

@@ -1935,6 +1935,27 @@ class TorchCompileTesterMixin:
_ = model(**inputs_dict)
_ = model(**inputs_dict)
def test_torch_compile_repeated_blocks(self):
if self.model_class._repeated_blocks is None:
pytest.skip("Skipping test as `_repeated_blocks` is not set in the model class.")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.compile_repeated_blocks(fullgraph=True)
recompile_limit = 1
if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=recompile_limit),
torch.no_grad(),
):
_ = model(**inputs_dict)
_ = model(**inputs_dict)
def test_compile_with_group_offloading(self):
torch._dynamo.config.cache_size_limit = 10000