mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-11 15:04:45 +08:00)

Compare commits: modular-wa...anijain230 (7 commits)
| Author | SHA1 | Date |
|---|---|---|
| | e229ed1538 | |
| | e2c7d17510 | |
| | 45ff4a8827 | |
| | 9196f3d1ba | |
| | 13eca6ef2d | |
| | 932914f45d | |
| | f794d66f1e | |
@@ -152,9 +152,39 @@ Compilation is slow the first time, but once compiled, it is significantly faster
### Regional compilation
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) reduces the cold start compilation time by only compiling a specific repeated region (or block) of the model instead of the entire model. The compiler reuses the cached and compiled code for the other blocks.
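For illustration only, here is a minimal, hand-rolled sketch of that idea, assuming an SDXL pipeline whose UNet repeats `BasicTransformerBlock` and a recent PyTorch with the in-place `nn.Module.compile` helper; the APIs described below automate exactly this pattern:

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Compile each repeated block in place; torch.compile caches one compiled
# artifact and reuses it for every block with the same structure.
for module in pipe.unet.modules():
    if module.__class__.__name__ == "BasicTransformerBlock":
        module.compile(fullgraph=True)
```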
[Accelerate](https://huggingface.co/docs/accelerate/index) provides the [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method for automatically compiling the repeated blocks of an `nn.Module` sequentially. The rest of the model is compiled separately.
[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by compiling **only the small, frequently repeated block(s)** of a model, typically a Transformer layer, and reusing the compiled artifacts for every subsequent occurrence. For many diffusion architectures this delivers the *same* runtime speed-ups as full-graph compilation yet cuts compile time by **8–10×**.

To make this effortless, [`ModelMixin`] exposes the [`ModelMixin.compile_repeated_blocks`] API, a helper that wraps `torch.compile` around any sub-modules you designate as repeatable:
```py
# pip install -U diffusers
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

# Compile only the repeated Transformer layers inside the UNet
pipe.unet.compile_repeated_blocks(fullgraph=True)
```
To enable a new model with regional compilation, add a `_repeated_blocks` attribute to your model class containing the class names (as strings) of the blocks you want compiled:
```py
class MyUNet(ModelMixin):
    _repeated_blocks = ("Transformer2DModel",)  # ← compiled by default
```
For more examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
**Relation to Accelerate compile_regions**: There is also a separate API in [accelerate](https://huggingface.co/docs/accelerate/index), [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78). It takes a fully automatic approach: it walks the module, picks candidate blocks, compiles them, and then compiles the remaining graph separately. That hands-off experience is handy for quick experiments, but it also leaves fewer knobs when you want to fine-tune which blocks are compiled or adjust compilation flags.
@@ -167,6 +197,8 @@

```py
# pip install -U accelerate
import torch
from diffusers import StableDiffusionXLPipeline

# compile_regions lives in accelerate.utils (see the link above)
from accelerate.utils import compile_regions

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
`compile_repeated_blocks`, by contrast, is intentionally explicit. You list the repeated blocks once (via `_repeated_blocks`) and the helper compiles exactly those, nothing more. In practice this small dose of control hits a sweet spot for diffusion models: predictable behavior, easy reasoning about cache reuse, and still a one-liner for users.
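Because `compile_repeated_blocks` forwards its arguments verbatim to `torch.compile`, the usual compilation knobs stay available. A small sketch reusing the `pipe` object from the example above; the flags shown are ordinary `torch.compile` options chosen only for illustration:

```py
# Each keyword argument is forwarded verbatim to torch.compile
# for every block listed in `_repeated_blocks`.
pipe.unet.compile_repeated_blocks(fullgraph=True, mode="max-autotune", dynamic=False)
```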
### Graph breaks
@@ -241,4 +273,4 @@ An input is projected into three subspaces, represented by the projection matric
```py
pipeline.fuse_qkv_projections()
```
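For intuition, here is a small self-contained sketch, not the diffusers implementation, of what fusing the Q, K, and V projections means: the three projection matrices are concatenated so the input is projected once through a single, larger matmul, and the result is split back into the three subspaces.

```py
import torch
import torch.nn as nn

hidden = 64
to_q, to_k, to_v = (nn.Linear(hidden, hidden) for _ in range(3))

# Concatenate the three projection matrices into one fused projection.
fused = nn.Linear(hidden, 3 * hidden)
with torch.no_grad():
    fused.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight], dim=0))
    fused.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias], dim=0))

x = torch.randn(2, 16, hidden)
q, k, v = fused(x).chunk(3, dim=-1)  # one matmul, then split into the three subspaces
assert torch.allclose(q, to_q(x), atol=1e-5)
```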
@@ -266,6 +266,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
    _keep_in_fp32_modules = None
    _skip_layerwise_casting_patterns = None
    _supports_group_offloading = True
    _repeated_blocks = []

    def __init__(self):
        super().__init__()
@@ -1404,6 +1405,39 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        else:
            return super().float(*args)

    def compile_repeated_blocks(self, *args, **kwargs):
        """
        Compiles *only* the frequently repeated sub-modules of a model (e.g. the Transformer layers) instead of
        compiling the entire model. This technique, often called **regional compilation** (see the PyTorch recipe
        https://docs.pytorch.org/tutorials/recipes/regional_compilation.html), can reduce end-to-end compile time
        substantially, while preserving the runtime speed-ups you would expect from a full `torch.compile`.

        The set of sub-modules to compile is discovered by the presence of the **`_repeated_blocks`** attribute in
        the model definition. Define this attribute on your model subclass as a list/tuple of class names (strings).
        Every module whose class name matches will be compiled.

        Once discovered, each matching sub-module is compiled by calling `submodule.compile(*args, **kwargs)`. Any
        positional or keyword arguments you supply to `compile_repeated_blocks` are forwarded verbatim to
        `torch.compile`.
        """
        repeated_blocks = getattr(self, "_repeated_blocks", None)

        if not repeated_blocks:
            raise ValueError(
                "`_repeated_blocks` attribute is empty. "
                f"Set `_repeated_blocks` for the class `{self.__class__.__name__}` to benefit from faster compilation. "
            )
        has_compiled_region = False
        for submod in self.modules():
            if submod.__class__.__name__ in repeated_blocks:
                submod.compile(*args, **kwargs)
                has_compiled_region = True

        if not has_compiled_region:
            raise ValueError(
                f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
            )

    @classmethod
    def _load_pretrained_model(
        cls,
@@ -407,6 +407,7 @@ class ChromaTransformer2DModel(
    _supports_gradient_checkpointing = True
    _no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
    _repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

    @register_to_config
@@ -227,6 +227,7 @@ class FluxTransformer2DModel(
    _supports_gradient_checkpointing = True
    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

    @register_to_config
    def __init__(
@@ -870,6 +870,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
        "HunyuanVideoPatchEmbed",
        "HunyuanVideoTokenRefiner",
    ]
    _repeated_blocks = [
        "HunyuanVideoTransformerBlock",
        "HunyuanVideoSingleTransformerBlock",
        "HunyuanVideoPatchEmbed",
        "HunyuanVideoTokenRefiner",
    ]

    @register_to_config
    def __init__(
@@ -328,6 +328,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["norm"]
    _repeated_blocks = ["LTXVideoTransformerBlock"]

    @register_to_config
    def __init__(
@@ -345,6 +345,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
    _no_split_modules = ["WanTransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
    _repeated_blocks = ["WanTransformerBlock"]

    @register_to_config
    def __init__(
@@ -167,6 +167,7 @@ class UNet2DConditionModel(
    _supports_gradient_checkpointing = True
    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
    _skip_layerwise_casting_patterns = ["norm"]
    _repeated_blocks = ["BasicTransformerBlock"]

    @register_to_config
    def __init__(
@@ -1935,6 +1935,27 @@ class TorchCompileTesterMixin:
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)

    def test_torch_compile_repeated_blocks(self):
        if self.model_class._repeated_blocks is None:
            pytest.skip("Skipping test as `_repeated_blocks` is not set in the model class.")

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        model = self.model_class(**init_dict).to(torch_device)
        model.compile_repeated_blocks(fullgraph=True)

        recompile_limit = 1
        if self.model_class.__name__ == "UNet2DConditionModel":
            recompile_limit = 2

        with (
            torch._inductor.utils.fresh_inductor_cache(),
            torch._dynamo.config.patch(recompile_limit=recompile_limit),
            torch.no_grad(),
        ):
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)

    def test_compile_with_group_offloading(self):
        torch._dynamo.config.cache_size_limit = 10000