Mirror of https://github.com/huggingface/diffusers.git (synced 2026-04-03 22:31:46 +08:00)

Compare commits: 8 commits on `main`

| SHA1 |
|---|
| b114620d85 |
| 447e571ada |
| 5adc544b79 |
| a05c8e9452 |
| 8070f6ec54 |
| 3e53a383e1 |
| cf6af6b4f8 |
| 3211cd9df0 |
@@ -148,5 +148,6 @@ ComponentSpec(

- [ ] Create pipeline class with `default_blocks_name`
- [ ] Assemble blocks in `modular_blocks_<model>.py`
- [ ] Wire up `__init__.py` with lazy imports
- [ ] Add `# auto_docstring` above all assembled blocks (SequentialPipelineBlocks, AutoPipelineBlocks, etc.), run `python utils/modular_auto_docstring.py --fix_and_overwrite`, and verify the generated docstrings — all parameters should have proper descriptions with no "TODO" placeholders indicating missing definitions
- [ ] Run `make style` and `make quality`
- [ ] Test all workflows for parity with reference
@@ -112,6 +112,8 @@

    title: ModularPipeline
  - local: modular_diffusers/components_manager
    title: ComponentsManager
  - local: modular_diffusers/auto_docstring
    title: Auto docstring and parameter templates
  - local: modular_diffusers/custom_blocks
    title: Building Custom Blocks
  - local: modular_diffusers/mellon
157 docs/source/en/modular_diffusers/auto_docstring.md Normal file

@@ -0,0 +1,157 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Auto docstring and parameter templates

Every [`~modular_pipelines.ModularPipelineBlocks`] has a `doc` property that is automatically generated from its `description`, `inputs`, `intermediate_outputs`, `expected_components`, and `expected_configs`. The auto docstring system keeps docstrings in sync with the block's actual interface. Parameter templates provide standardized descriptions for parameters that appear across many pipelines.

## Auto docstring

Modular pipeline blocks are composable — you can nest them, chain them in sequences, and rearrange them freely. Their docstrings follow the same pattern. When a [`~modular_pipelines.SequentialPipelineBlocks`] aggregates inputs and outputs from its sub-blocks, the documentation should update automatically without manual rewrites.

The `# auto_docstring` marker generates docstrings from the block's properties. Add it above a class definition to mark the class for automatic docstring generation.

```py
# auto_docstring
class FluxTextEncoderStep(SequentialPipelineBlocks):
    ...
```

Run the following command to generate and insert the docstrings.

```bash
python utils/modular_auto_docstring.py --fix_and_overwrite
```

The utility reads the block's `doc` property and inserts it as the class docstring.

```py
# auto_docstring
class FluxTextEncoderStep(SequentialPipelineBlocks):
    """
    Text input processing step that standardizes text embeddings for the pipeline.

    Inputs:
        prompt_embeds (`torch.Tensor`) *required*:
            text embeddings used to guide the image generation.
        ...

    Outputs:
        prompt_embeds (`torch.Tensor`):
            text embeddings used to guide the image generation.
        ...
    """
```

You can also check without overwriting, or target a specific file or directory.

```bash
# Check that all marked classes have up-to-date docstrings
python utils/modular_auto_docstring.py

# Check a specific file or directory
python utils/modular_auto_docstring.py src/diffusers/modular_pipelines/flux/
```

If any marked class is missing a docstring, the check fails and lists the classes that need updating.

```
Found the following # auto_docstring markers that need docstrings:
- src/diffusers/modular_pipelines/flux/encoders.py: FluxTextEncoderStep at line 42

Run `python utils/modular_auto_docstring.py --fix_and_overwrite` to fix them.
```
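The generated text is also available at runtime through the block's `doc` property, so you can preview it before committing the docstring. A quick sketch (the exact import path is an assumption and may vary between diffusers versions):

```py
# Hypothetical import path; FluxTextEncoderStep is defined in the Flux modular pipeline module.
from diffusers.modular_pipelines.flux import FluxTextEncoderStep

block = FluxTextEncoderStep()
print(block.doc)  # the same text the utility inserts as the class docstring
```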
## Parameter templates

`InputParam` and `OutputParam` define a block's inputs and outputs. Create them directly or use `.template()` for standardized definitions of common parameters like `prompt`, `num_inference_steps`, or `latents`.

### InputParam

[`~modular_pipelines.InputParam`] describes a single input to a block.

| Field | Type | Description |
|---|---|---|
| `name` | `str` | Name of the parameter |
| `type_hint` | `Any` | Type annotation (e.g., `str`, `torch.Tensor`) |
| `default` | `Any` | Default value (if not set, parameter has no default) |
| `required` | `bool` | Whether the parameter is required |
| `description` | `str` | Human-readable description |
| `kwargs_type` | `str` | Group name for related parameters (e.g., `"denoiser_input_fields"`) |
| `metadata` | `dict` | Arbitrary additional information |

#### Creating InputParam directly

```py
from diffusers.modular_pipelines import InputParam

InputParam(
    name="guidance_scale",
    type_hint=float,
    default=7.5,
    description="Scale for classifier-free guidance.",
)
```

#### Using a template

```py
InputParam.template("prompt")
# Equivalent to:
# InputParam(name="prompt", type_hint=str, required=True,
#            description="The prompt or prompts to guide image generation.")
```

Templates set `name`, `type_hint`, `default`, `required`, and `description` automatically. Override any field or add context with the `note` parameter.

```py
# Override the default value
InputParam.template("num_inference_steps", default=28)

# Add a note to the description
InputParam.template("prompt_embeds", note="batch-expanded")
# description becomes: "text embeddings used to guide the image generation. ... (batch-expanded)"
```
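Templates are most useful when declaring a block's interface. The sketch below mixes templates and a direct definition in a block's `inputs` list; `MyDenoiseStep` is a hypothetical block, and the full block anatomy is covered in the Building Custom Blocks guide.

```py
from diffusers.modular_pipelines import InputParam, ModularPipelineBlocks


# auto_docstring
class MyDenoiseStep(ModularPipelineBlocks):  # hypothetical block, for illustration only
    @property
    def inputs(self):
        return [
            InputParam.template("prompt"),  # standardized definition
            InputParam.template("num_inference_steps", default=28),  # template with an override
            InputParam(  # fully custom parameter
                name="guidance_scale",
                type_hint=float,
                default=7.5,
                description="Scale for classifier-free guidance.",
            ),
        ]
```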
### OutputParam

[`~modular_pipelines.OutputParam`] describes a single output from a block.

| Field | Type | Description |
|---|---|---|
| `name` | `str` | Name of the parameter |
| `type_hint` | `Any` | Type annotation |
| `description` | `str` | Human-readable description |
| `kwargs_type` | `str` | Group name for related parameters |
| `metadata` | `dict` | Arbitrary additional information |

#### Creating OutputParam directly

```py
import torch

from diffusers.modular_pipelines import OutputParam

OutputParam(name="image_latents", type_hint=torch.Tensor, description="Encoded image latents.")
```

#### Using a template

```py
OutputParam.template("latents")

# Add a note to the description
OutputParam.template("prompt_embeds", note="batch-expanded")
```

## Available templates

`INPUT_PARAM_TEMPLATES` and `OUTPUT_PARAM_TEMPLATES` are defined in [modular_pipeline_utils.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/modular_pipelines/modular_pipeline_utils.py). They include common parameters like `prompt`, `image`, `num_inference_steps`, `latents`, `prompt_embeds`, and more. Refer to the source for the full list of available template names.
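To list the template names programmatically, you can inspect the registries directly. A quick sketch, assuming the two objects are dict-like and keyed by parameter name:

```py
from diffusers.modular_pipelines.modular_pipeline_utils import (
    INPUT_PARAM_TEMPLATES,
    OUTPUT_PARAM_TEMPLATES,
)

# Assumes the registries behave like dictionaries keyed by parameter name.
print(sorted(INPUT_PARAM_TEMPLATES))
print(sorted(OUTPUT_PARAM_TEMPLATES))
```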
@@ -100,7 +100,7 @@ accelerate launch train_text_to_image_sdxl.py \

The training script is also similar to the [Text-to-image](text2image#training-script) training guide, but it's been modified to support SDXL training. This guide will focus on the code that is unique to the SDXL training script.

It starts by creating functions to [tokenize the prompts](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L478) to calculate the prompt embeddings, and to compute the image embeddings with the [VAE](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L519). Next, you'll a function to [generate the timesteps weights](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L531) depending on the number of timesteps and the timestep bias strategy to apply.
It starts by creating functions to [tokenize the prompts](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L478) to calculate the prompt embeddings, and to compute the image embeddings with the [VAE](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L519). Next, you'll create a function to [generate the timesteps weights](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L531) depending on the number of timesteps and the timestep bias strategy to apply.

Within the [`main()`](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L572) function, in addition to loading a tokenizer, the script loads a second tokenizer and text encoder because the SDXL architecture uses two of each:

@@ -250,5 +250,5 @@ print(f'Inference time is {time()-start} sec after compilation')

Congratulations on training a SDXL model! To learn more about how to use your new model, the following guides may be helpful:

- Read the [Stable Diffusion XL](../using-diffusers/sdxl) guide to learn how to use it for a variety of different tasks (text-to-image, image-to-image, inpainting), how to use it's refiner model, and the different types of micro-conditionings.
- Read the [Stable Diffusion XL](../using-diffusers/sdxl) guide to learn how to use it for a variety of different tasks (text-to-image, image-to-image, inpainting), how to use its refiner model, and the different types of micro-conditionings.
- Check out the [DreamBooth](dreambooth) and [LoRA](lora) training guides to learn how to train a personalized SDXL model with just a few example images. These two training techniques can even be combined!
@@ -111,7 +111,7 @@ It conditions on a monocular depth estimate of the original image.

[Paper](https://huggingface.co/papers/2302.08113)

MultiDiffusion Panorama defines a new generation process over a pre-trained diffusion model. This process binds together multiple diffusion generation methods that can be readily applied to generate high quality and diverse images. Results adhere to user-provided controls, such as desired aspect ratio (e.g., panorama), and spatial guiding signals, ranging from tight segmentation masks to bounding boxes.
MultiDiffusion Panorama allows to generate high-quality images at arbitrary aspect ratios (e.g., panoramas).
MultiDiffusion Panorama allows you to generate high-quality images at arbitrary aspect ratios (e.g., panoramas).

## Fine-tuning your own models
@@ -60,7 +60,7 @@ print(np.abs(image).sum())

</hfoption>
</hfoptions>

The `Generator` object should be passed to the pipeline instead of an integer seed. `Generator` maintains a *random state* that is consumed and modified when used. Once consumed, the same `Generator` object produces different results in subsequent calls, even across different pipelines, because it's *state* has changed.
The `Generator` object should be passed to the pipeline instead of an integer seed. `Generator` maintains a *random state* that is consumed and modified when used. Once consumed, the same `Generator` object produces different results in subsequent calls, even across different pipelines, because its *state* has changed.

```py
generator = torch.manual_seed(0)
342 examples/profiling/README.md Normal file

@@ -0,0 +1,342 @@
# Profiling a `DiffusionPipeline` with the PyTorch Profiler

Educational materials on strategically profiling pipelines to potentially improve their
runtime with `torch.compile`. To set these pipelines up for success with `torch.compile`,
we often have to get rid of device-to-host (DtoH) syncs, CPU overheads, kernel launch delays, and
graph breaks. In this context, profiling serves that purpose for us.

Thanks to Claude Code for pair coding! We acknowledge the [Claude for OSS](https://claude.com/contact-sales/claude-for-oss) support provided to us.

## Table of contents

* [Context](#context)
* [Target pipelines](#target-pipelines)
* [How the tooling works](#how-the-tooling-works)
* [Verification](#verification)
* [Interpretation of profiling traces](#interpreting-traces-in-perfetto-ui)
* [Taking profiling-guided steps for improvements](#what-profiling-revealed-and-fixes)

Jump to the "Verification" section to get started right away.

## Context

We want to uncover CPU overhead, CPU-GPU sync points, and other bottlenecks in popular diffusers pipelines — especially issues that become non-trivial when using [`torch.compile`](https://docs.pytorch.org/docs/stable/generated/torch.compile.html). The approach is inspired by [flux-fast's run_benchmark.py](https://github.com/huggingface/flux-fast/blob/0a1dcc91658f0df14cd7fce862a5c8842784c6da/run_benchmark.py#L66-L85), which uses [`torch.profiler`](https://docs.pytorch.org/docs/stable/profiler.html) with method-level annotations, and motivated by issues like [diffusers#11696](https://github.com/huggingface/diffusers/pull/11696) (DtoH sync from a scheduler `.item()` call).

## Target Pipelines

We wanted to start with some of our most popular and widely-used pipelines:

| Pipeline | Type | Checkpoint | Steps |
|----------|------|-----------|-------|
| `FluxPipeline` | text-to-image | `black-forest-labs/FLUX.1-dev` | 2 |
| `Flux2KleinPipeline` | text-to-image | `black-forest-labs/FLUX.2-klein-base-9B` | 2 |
| `WanPipeline` | text-to-video | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | 2 |
| `LTX2Pipeline` | text-to-video | `Lightricks/LTX-2` | 2 |
| `QwenImagePipeline` | text-to-image | `Qwen/Qwen-Image` | 2 |

> [!NOTE]
> We use realistic inference call hyperparameters that mimic how these pipelines will actually be used. This
> includes using classifier-free guidance (where applicable), reasonable dimensions such as 1024x1024, etc.
> But we keep the number of inference steps to a bare minimum.
## How the Tooling Works

Follow the flux-fast pattern: **annotate key pipeline methods** with `torch.profiler.record_function` wrappers, then run the pipeline under `torch.profiler.profile` and export a Chrome JSON trace.

### New Files

```bash
profiling_utils.py      # Annotation helper + profiler setup
profiling_pipelines.py  # CLI entry point with pipeline configs
run_profiling.sh        # Bulk launch runs for multiple pipelines
```

### Step 1: `profiling_utils.py` — Annotation and Profiler Infrastructure

**A) `annotate(func, name)` helper** (same pattern as flux-fast):

```python
def annotate(func, name):
    """Wrap a function with torch.profiler.record_function for trace annotation."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with torch.profiler.record_function(name):
            return func(*args, **kwargs)

    return wrapper
```

**B) `annotate_pipeline(pipe)` function** — applies annotations to key methods on any pipeline:

- `pipe.transformer.forward` → `"transformer_forward"`
- `pipe.vae.decode` → `"vae_decode"` (if present)
- `pipe.vae.encode` → `"vae_encode"` (if present)
- `pipe.scheduler.step` → `"scheduler_step"`
- `pipe.encode_prompt` → `"encode_prompt"` (if present, for full-pipeline profiling)

This is non-invasive — it monkey-patches bound methods without modifying source.
**C) `PipelineProfiler` class:**

- `__init__(pipeline_config, output_dir, mode="eager"|"compile")`
- `setup_pipeline()` → loads from pretrained, optionally compiles transformer, calls `annotate_pipeline()`
- `run()`:
  1. Warm up with 1 unannotated run
  2. Profile 1 run with `torch.profiler.profile`:
     - `activities=[CPU, CUDA]`
     - `record_shapes=True`
     - `profile_memory=True`
     - `with_stack=True`
  3. Export Chrome trace JSON
  4. Print `key_averages()` summary table (sorted by CUDA time) to stdout

`PipelineProfiler` also has a `benchmark()` method that can measure the total runtime of a pipeline.

### Step 2: `profiling_pipelines.py` — CLI with Pipeline Configs

**Pipeline config registry** — each entry specifies:

- `pipeline_cls`, `pretrained_model_name_or_path`, `torch_dtype`
- `call_kwargs` with pipeline-specific defaults:

| Pipeline | Resolution | Frames | Steps | Extra |
|----------|-----------|--------|-------|-------|
| Flux | 1024x1024 | — | 2 | `guidance_scale=3.5` |
| Flux2Klein | 1024x1024 | — | 2 | `guidance_scale=3.5` |
| Wan | 480x832 | 81 | 2 | — |
| LTX2 | 768x512 | 121 | 2 | `guidance_scale=4.0` |
| QwenImage | 1024x1024 | — | 2 | `true_cfg_scale=4.0` |

All configs use `output_type="latent"` by default (skip VAE decode for cleaner denoising-loop traces).

**CLI flags:**

- `--pipeline flux|flux2|wan|ltx2|qwenimage|all`
- `--mode eager|compile|both`
- `--output_dir profiling_results/`
- `--num_steps N` (override, default 4)
- `--full_decode` (switch output_type from `"latent"` to `"pil"` to include VAE)
- `--compile_mode default|reduce-overhead|max-autotune`
- `--compile_regional` flag (uses [regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html) to compile only the transformer forward pass instead of the full pipeline — faster compile times, ideal for iterative profiling)
- `--compile_fullgraph` flag to ensure there are no graph breaks

**Output:** `{output_dir}/{pipeline}_{mode}.json` Chrome trace + stdout summary.
### Step 3: Known Sync Issues to Validate

The profiling should surface these known/suspected issues:

1. **Scheduler DtoH sync via `nonzero().item()`** — For Flux, this was fixed by adding `scheduler.set_begin_index(0)` before the denoising loop ([diffusers#11696](https://github.com/huggingface/diffusers/pull/11696)). Profiling should reveal whether similar sync points exist in other pipelines.

2. **`modulate_index` tensor rebuilt every forward in `transformer_qwenimage.py`** (line 901-905) — Python list comprehension + `torch.tensor()` each step. Minor but visible in trace.

3. **Any other `.item()`, `.cpu()`, `.numpy()` calls** in the denoising loop hot path — the profiler's `with_stack=True` will surface these as CPU stalls with Python stack traces.

## Verification

1. Run: `python examples/profiling/profiling_pipelines.py --pipeline flux --mode eager --num_steps 2`
2. Verify `profiling_results/flux_eager.json` is produced
3. Open trace in [Perfetto UI](https://ui.perfetto.dev/) — confirm:
   - `transformer_forward` and `scheduler_step` annotations visible
   - CPU and CUDA timelines present
   - Stack traces visible on CPU events
4. Run with `--mode compile`: `python examples/profiling/profiling_pipelines.py --pipeline flux --mode compile --compile_regional --num_steps 2` and compare trace for fewer/fused CUDA kernels

You can also use the `run_profiling.sh` script to bulk launch runs for different pipelines.

## Interpreting Traces in Perfetto UI

Open the exported `.json` trace at [ui.perfetto.dev](https://ui.perfetto.dev/). The trace has two main rows: **CPU** (top) and **CUDA** (bottom). In Perfetto, the CPU row is typically labeled with the process/thread name (e.g., `python (PID)` or `MainThread`) and appears at the top. The CUDA row is labeled `GPU 0` (or similar) and appears below the CPU rows.

**Navigation:** Use `W` to zoom in, `S` to zoom out, and `A`/`D` to pan left/right. You can also scroll to zoom and click-drag to pan. Use `Shift+scroll` to scroll vertically through rows.

> [!IMPORTANT]
> To keep the profiling iterations fast, we always use [regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html). The observations below would largely still apply for full model
> compilation, too.

### What to look for

**1. Gaps between CUDA kernels**

Zoom into the CUDA row during the denoising loop. Ideally, GPU kernels should be back-to-back with no gaps. Gaps mean the GPU is idle waiting for the CPU to launch the next kernel. Common causes:
- Python overhead between ops (visible as CPU slices in the CPU row during the gap)
- DtoH sync (`.item()`, `.cpu()`) forcing the GPU to drain before the CPU can proceed

> [!IMPORTANT]
> No bubbles/gaps is ideal, but for small shapes (small model, small batch size, or both) some bubbles could be unavoidable.

**2. CPU stalls (DtoH syncs)**

These appear on the **CPU row** (not the CUDA row) — they are CPU-side blocking calls that wait for the GPU to finish. Look for long slices labeled `cudaStreamSynchronize` or `cudaDeviceSynchronize`. To find them: zoom into the CPU row during a denoising step and look for unusually wide slices, or use Perfetto's search bar (press `/`) and type `cudaStreamSynchronize` to jump directly to matching events. Click on a slice — if `with_stack=True` was enabled, the bottom panel ("Current Selection") shows the Python stack trace pointing to the exact line causing the sync (e.g., a `.item()` call in the scheduler).

**3. Annotated regions**

Our `record_function` annotations (`transformer_forward`, `scheduler_step`, etc.) appear as labeled spans on the CPU row. This lets you quickly:
- Measure how long each phase takes (click a span to see duration)
- See if `scheduler_step` is disproportionately expensive relative to `transformer_forward` (it should be negligible)
- Spot unexpected CPU work between annotated regions

**4. Eager vs compile comparison**

Open both traces side by side (two Perfetto tabs). Key differences to look for:
- **Fewer, wider CUDA kernels** in compile mode (fused ops) vs many small kernels in eager
- **Smaller CPU gaps** between kernels in compile mode (less Python dispatch overhead)
- **CUDA kernel count per step**: to compare, zoom into a single `transformer_forward` span on the CUDA row and count the distinct kernel slices within it. In eager mode you'll typically see many narrow slices (one per op); in compile mode these fuse into fewer, wider slices. A quick way to estimate: select a time range covering one denoising step on the CUDA row — Perfetto shows the number of slices in the selection summary at the bottom. If compile mode shows a similar kernel count to eager, fusion isn't happening effectively (likely due to graph breaks).
- **Graph breaks**: if compile mode still shows many small kernels in a section, that section likely has a graph break — check `TORCH_LOGS="+dynamo"` output for details
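To confirm suspected graph breaks from the command line, re-run the compile-mode profile with dynamo logging enabled; the invocation below just adds the environment variable to the CLI described earlier.

```bash
# Log graph breaks and recompilations while profiling in compile mode
TORCH_LOGS="+dynamo" python examples/profiling/profiling_pipelines.py \
  --pipeline flux --mode compile --compile_regional --num_steps 2
```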
**5. Memory timeline**

In Perfetto, look for the memory counter track (if `profile_memory=True`). Spikes during the denoising loop suggest unexpected allocations per step. Steady-state memory during denoising is expected — growing memory is not.
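A complementary way to check for per-step growth outside the profiler is a step-end callback. A small sketch, assuming the pipeline accepts diffusers' `callback_on_step_end` argument:

```py
import torch


def log_memory(pipe, step, timestep, callback_kwargs):
    # Peak allocated memory after each denoising step; a value that keeps growing
    # step over step points at the same per-step allocations that appear as spikes in the trace.
    print(f"step {step}: {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")
    return callback_kwargs


# pipe(**call_kwargs, callback_on_step_end=log_memory)
```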
**6. Kernel launch latency**

Each CUDA kernel is launched from the CPU. The CPU-side launch calls (`cudaLaunchKernel`) appear as small slices on the **CPU row** — zoom in closely to a denoising step to see them. The corresponding GPU-side kernel executions appear on the **CUDA row** directly below. You can also use Perfetto's search bar (`/`) and type `cudaLaunchKernel` to find them. The time between the CPU dispatch and the GPU kernel starting should be minimal (single-digit microseconds). If you see consistent delays > 10-20us between launch and execution:
- The launch queue may be starved because of excessive Python work between ops
- There may be implicit syncs forcing serialization
- `torch.compile` should help here by batching launches — compare eager vs compile to confirm

To inspect this: zoom into a single denoising step, select a CUDA kernel on the GPU row, and look at the corresponding CPU-side launch slice directly above it (there should be an arrow pointing from the CPU launch slice to the GPU kernel slice). The horizontal offset between them is the launch latency. In a healthy trace, CPU launch slices should be well ahead of GPU execution (the CPU is "feeding" the GPU faster than it can consume).

### Quick checklist per pipeline

| Question | Where to look | Healthy | Unhealthy |
|----------|--------------|---------|-----------|
| GPU staying busy? | CUDA row gaps | Back-to-back kernels | Frequent gaps > 100us |
| CPU blocking on GPU? | `cudaStreamSynchronize` slices | Rare/absent during denoise | Present every step |
| Scheduler overhead? | `scheduler_step` span duration | < 1% of step time | > 5% of step time |
| Compile effective? | CUDA kernel count per step | Fewer large kernels | Same as eager |
| Kernel launch latency? | CPU launch → GPU kernel offset | < 10us, CPU ahead of GPU | > 20us or CPU trailing GPU |
| Memory stable? | Memory counter track | Flat during denoise loop | Growing per step |
## What Profiling Revealed and Fixes

As one would expect, the trace with compilation shows fewer kernel launches than its eager counterpart.

_(Unless otherwise specified, the traces below were obtained with **Flux2**.)_

<table>
<tr>
  <td align="center">
    <img src="https://huggingface.co/datasets/sayakpaul/torch-profiling-trace-diffusers/resolve/main/Flux2-Klein/Screenshot%202026-03-27%20at%2011.03.39%E2%80%AFAM.png" alt="Image 1"><br>
    <em>Without compile</em>
  </td>
  <td align="center">
    <img src="https://huggingface.co/datasets/sayakpaul/torch-profiling-trace-diffusers/resolve/main/Flux2-Klein/Screenshot%202026-03-27%20at%2011.05.06%E2%80%AFAM.png" alt="Image 2"><br>
    <em>With compile</em>
  </td>
</tr>
</table>

### Spotting gaps between launches

A reasonable next step is to spot frequent gaps between kernel executions. In the compiled
case, we don't spot any on the surface. But if we zoom in, some become apparent.

<table>
<tr>
  <td align="center">
    <img src="https://huggingface.co/datasets/sayakpaul/torch-profiling-trace-diffusers/resolve/main/Flux2-Klein/Screenshot%202026-03-27%20at%2011.16.42%E2%80%AFAM.png" alt="Image 1"><br>
    <em>Very small visible gaps in between compiled regions</em>
  </td>
  <td align="center">
    <img src="https://huggingface.co/datasets/sayakpaul/torch-profiling-trace-diffusers/resolve/main/Flux2-Klein/Screenshot%202026-03-27%20at%2010.24.34%E2%80%AFAM.png" alt="Image 2"><br>
    <em>Gaps become more visible when zoomed in</em>
  </td>
</tr>
</table>

So, we provided the profile trace file (with compilation) to Claude, asked it to find the instances of
`cudaStreamSynchronize` and `cudaDeviceSynchronize`, and to come up with some potential fixes.
Claude came back with the following:

```
Issue 1 — Gap between transformer forwards:
- Root cause: tqdm progress bar update() calls between steps add CPU overhead (I/O, time calculations)
- Fix: profiling/profiling_utils.py — added pipe.set_progress_bar_config(disable=True) during profiling setup.
  This eliminates the tqdm overhead from the trace. (The remaining gap from scheduler step + Python dispatch is
  inherent to eager-mode execution and should shrink significantly under torch.compile.)

Issue 2 — cudaStreamSynchronize during last transformer forward:
- Root cause: _unpack_latents_with_ids() (called right after the denoising loop) computes h = torch.max(h_ids) +
  1 and w = torch.max(w_ids) + 1 on GPU tensors, then uses them as shape args for torch.zeros((h * w, ch), ...).
  This triggers an implicit .item() DtoH sync, blocking the CPU while the GPU is still finishing the last
  transformer forward's kernels.
- Fix: Added height/width parameters to _unpack_latents_with_ids(), pre-computed from the known pixel dimensions
  at the call site.
```

The changes looked reasonable based on our past experience. So, we asked Claude to apply these changes to [`pipeline_flux2_klein.py`](../../src/diffusers/pipelines/flux2/pipeline_flux2_klein.py). We then profiled
the updated pipeline. It still didn't completely eliminate the gaps as expected, so we fed that back to Claude and
asked it to analyze what was filling those gaps now.

#### Discovering `cache_context` as the real bottleneck

Claude parsed the updated trace and broke down the CPU events in each gap between `transformer_forward` spans. The results were revealing: the dominant cost was no longer tqdm or syncs — it was `src/diffusers/hooks/hooks.py: _set_context` at **~2.7ms per call**, filled with hundreds of `named_modules()` slices.

Here's what was happening: under the [`cache_context`](https://github.com/huggingface/diffusers/blob/f2be8bd6b3dc4035bd989dc467f15d86bf3c9c12/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py#L842) manager, there is a call to `_set_context()` on enter and exit. It calls `named_modules()` on the entire underlying model (in this case the Flux2 Klein DiT).

For large models invoked iteratively, as in our case, this adds latency because it traverses hundreds of submodules. With 8 context switches per iteration (enter/exit for each `cache_context` call), this added up to **21.6ms** of pure Python overhead per denoising iteration.

The first round of fixes (`tqdm`, `_unpack_latents_with_ids`) addressed real issues, but they were masking this larger one. Only after fixing them did the `_set_context` overhead become the clear dominant cost in the trace.

#### The fix — caching child registries

The module tree and hook registrations don't change during inference, so the `named_modules()` walk produces the same result every time. The fix was to build a list of hooked child registries once on the first call and cache it in `_child_registries_cache`. This way, subsequent calls return the cached list directly without
any traversal. With the fix applied, the improvements were visible.

| | Before | After |
|------------------------|------------------------------|-----------------------------|
| `_set_context` total | 21.6ms (8 calls) | 0.0ms (8 calls) |
| `cache_context` total | 21.7ms | 0.1ms |
| CPU gaps | 5,523us / 8,007us / 5,508us | 158us / 2,777us / 136us |
| Wall-clock runtime | 574.3ms (std 2.3ms) | 569.8ms (std 2.4ms) |

> [!NOTE]
> The wall-clock improvement here is modest (~0.8%) because the GPU is already the bottleneck for Flux2 Klein at this resolution — the CPU finishes dispatching well before the GPU finishes executing. The CPU overhead reduction (21.6ms → 0.0ms) is hidden behind GPU execution time. These fixes become more impactful with larger batch sizes and higher resolutions, where the GPU has a deeper queue of pending kernels and any sync point causes a longer stall. The numbers were obtained on a single H100 using regional compilation with 2 inference steps and 1024x1024 resolution (`--benchmark --num_runs 5 --num_warmups 2`).

> [!NOTE]
> The fixes mentioned above and below are available in [this PR](https://github.com/huggingface/diffusers/pull/13356).

### DtoH syncs

We also profiled the **Wan** model and uncovered problems related to CPU DtoH syncs. Below is an
overview.

First, there was a dynamo cache lookup delay making the GPU idle as reported [in this PR](https://github.com/huggingface/diffusers/pull/11696).

Similar to the above-mentioned PR, the fix was to call `self.scheduler.set_begin_index(0)` before the denoising loop. This tells the scheduler the starting index is 0, so `_init_step_index()` skips the `nonzero().item()` path (which was causing the sync) entirely. This fix eliminated the ~2.3s GPU idle time completely.
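To make the pattern concrete, here is a standalone sketch of calling `set_begin_index(0)` before stepping a scheduler. Random tensors stand in for the transformer forward, and a CUDA device is assumed:

```py
import torch
from diffusers import UniPCMultistepScheduler

scheduler = UniPCMultistepScheduler()
scheduler.set_timesteps(num_inference_steps=4, device="cuda")

# Setting the begin index up front lets _init_step_index() skip the
# (timesteps == t).nonzero() -> .item() lookup, which is a DtoH sync.
scheduler.set_begin_index(0)

sample = torch.randn(1, 16, 60, 104, device="cuda")
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)  # stand-in for the real transformer forward
    sample = scheduler.step(model_output, t, sample).prev_sample
```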
The UniPC scheduler (used in Wan) also had two more sync-causing patterns in `multistep_uni_p_bh_update` and `multistep_uni_c_bh_update`:

1. **`torch.tensor(rks, device=device)`** where `rks` is a list containing GPU scalar tensors. `torch.tensor()` pulls each GPU value back to CPU to construct a new tensor, triggering a DtoH sync.

   **Fix**: Replace with `torch.stack(rks)` which concatenates GPU tensors directly on the GPU — no sync needed. The appended Python float `1.0` was also changed to `torch.ones((), device=device)` so the list contains only GPU tensors.

2. **`torch.tensor([0.5], dtype=x.dtype, device=device)`** creates a small constant tensor from a CPU Python float. This triggers a `cudaMemcpyAsync` + `cudaStreamSynchronize` to copy the value from CPU to GPU. The sync itself is normally fast (~6us), but it forces the CPU to wait until all pending GPU kernels finish before proceeding. Under `torch.compile`, the GPU has many queued kernels, so this tiny sync balloons to 2.3s.

   **Fix**: Replace with `torch.ones(1, dtype=x.dtype, device=device) * 0.5`. `torch.ones` allocates on GPU via `cudaMemsetAsync` (no sync), and `* 0.5` is a CUDA kernel launch (no sync). Same result, zero CPU-GPU synchronization.
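Both patterns generalize beyond UniPC. A small sketch of the sync-causing constructions next to their sync-free equivalents (a CUDA device is assumed):

```py
import torch

device = "cuda"
rks = [torch.rand((), device=device) for _ in range(3)]  # GPU scalar tensors

# Sync-causing: torch.tensor() reads each GPU scalar back on the CPU (DtoH sync)
rks_sync = torch.tensor(rks + [1.0], device=device)

# Sync-free: everything stays on the GPU
rks_async = torch.stack(rks + [torch.ones((), device=device)])

# Sync-causing: building a constant from a Python float copies CPU -> GPU and synchronizes
half_sync = torch.tensor([0.5], dtype=torch.float32, device=device)

# Sync-free: allocate on the GPU, then scale with an ordinary kernel launch
half_async = torch.ones(1, dtype=torch.float32, device=device) * 0.5
```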
The duration of the scheduling step before and after these fixes confirms this:

<table>
<tr>
  <td align="center">
    <img src="https://huggingface.co/datasets/sayakpaul/torch-profiling-trace-diffusers/resolve/main/Wan/Screenshot%25202026-03-27%2520at%25206.04.06%25E2%2580%25AFPM.png" alt="Image 1"><br>
    <em>CPU<->GPU sync</em>
  </td>
  <td align="center">
    <img src="https://huggingface.co/datasets/sayakpaul/torch-profiling-trace-diffusers/resolve/main/Wan/Screenshot%25202026-03-27%2520at%25206.04.29%25E2%2580%25AFPM.png" alt="Image 2"><br>
    <em>Almost no sync</em>
  </td>
</tr>
</table>

### Notes

* As mentioned above, we profiled with regional compilation, so it's possible that there are still some gaps outside the compiled regions. A full compilation will likely mitigate it; if it doesn't, the observations above could help address it.
* Use of CUDA Graphs can also help mitigate CPU-overhead-related issues. CUDA Graphs can be enabled by setting the `torch.compile` mode to `"reduce-overhead"` or `"max-autotune"`.
* Diffusers' integration of `torch.compile` is documented [here](https://huggingface.co/docs/diffusers/main/en/optimization/fp16#torchcompile).
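For example, with the CLI described above, CUDA Graphs can be exercised through the compile mode flag:

```bash
python examples/profiling/profiling_pipelines.py \
  --pipeline flux --mode compile --compile_regional \
  --compile_mode reduce-overhead --benchmark --num_runs 5 --num_warmups 2
```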
196 examples/profiling/profiling_pipelines.py Normal file

@@ -0,0 +1,196 @@
"""
|
||||
Profile diffusers pipelines with torch.profiler.
|
||||
|
||||
Usage:
|
||||
python profiling/profiling_pipelines.py --pipeline flux --mode eager
|
||||
python profiling/profiling_pipelines.py --pipeline flux --mode compile
|
||||
python profiling/profiling_pipelines.py --pipeline flux --mode both
|
||||
python profiling/profiling_pipelines.py --pipeline all --mode eager
|
||||
python profiling/profiling_pipelines.py --pipeline wan --mode eager --full_decode
|
||||
python profiling/profiling_pipelines.py --pipeline flux --mode compile --num_steps 4
|
||||
|
||||
Benchmarking (wall-clock time, no profiler overhead):
|
||||
python profiling/profiling_pipelines.py --pipeline flux --mode compile --benchmark
|
||||
python profiling/profiling_pipelines.py --pipeline flux --mode both --benchmark --num_runs 10 --num_warmups 3
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from profiling_utils import PipelineProfiler, PipelineProfilingConfig
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROMPT = "A cat holding a sign that says hello world"
|
||||
|
||||
|
||||
def build_registry():
|
||||
"""Build the pipeline config registry. Imports are deferred to avoid loading all pipelines upfront."""
|
||||
from diffusers import Flux2KleinPipeline, FluxPipeline, LTX2Pipeline, QwenImagePipeline, WanPipeline
|
||||
|
||||
return {
|
||||
"flux": PipelineProfilingConfig(
|
||||
name="flux",
|
||||
pipeline_cls=FluxPipeline,
|
||||
pipeline_init_kwargs={
|
||||
"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev",
|
||||
"torch_dtype": torch.bfloat16,
|
||||
},
|
||||
pipeline_call_kwargs={
|
||||
"prompt": PROMPT,
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
"num_inference_steps": 4,
|
||||
"guidance_scale": 3.5,
|
||||
"output_type": "latent",
|
||||
},
|
||||
),
|
||||
"flux2": PipelineProfilingConfig(
|
||||
name="flux2",
|
||||
pipeline_cls=Flux2KleinPipeline,
|
||||
pipeline_init_kwargs={
|
||||
"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-klein-base-9B",
|
||||
"torch_dtype": torch.bfloat16,
|
||||
},
|
||||
pipeline_call_kwargs={
|
||||
"prompt": PROMPT,
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
"num_inference_steps": 4,
|
||||
"guidance_scale": 3.5,
|
||||
"output_type": "latent",
|
||||
},
|
||||
),
|
||||
"wan": PipelineProfilingConfig(
|
||||
name="wan",
|
||||
pipeline_cls=WanPipeline,
|
||||
pipeline_init_kwargs={
|
||||
"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
|
||||
"torch_dtype": torch.bfloat16,
|
||||
},
|
||||
pipeline_call_kwargs={
|
||||
"prompt": PROMPT,
|
||||
"negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
|
||||
"height": 480,
|
||||
"width": 832,
|
||||
"num_frames": 81,
|
||||
"num_inference_steps": 4,
|
||||
"output_type": "latent",
|
||||
},
|
||||
),
|
||||
"ltx2": PipelineProfilingConfig(
|
||||
name="ltx2",
|
||||
pipeline_cls=LTX2Pipeline,
|
||||
pipeline_init_kwargs={
|
||||
"pretrained_model_name_or_path": "Lightricks/LTX-2",
|
||||
"torch_dtype": torch.bfloat16,
|
||||
},
|
||||
pipeline_call_kwargs={
|
||||
"prompt": PROMPT,
|
||||
"negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
|
||||
"height": 512,
|
||||
"width": 768,
|
||||
"num_frames": 121,
|
||||
"num_inference_steps": 4,
|
||||
"guidance_scale": 4.0,
|
||||
"output_type": "latent",
|
||||
},
|
||||
),
|
||||
"qwenimage": PipelineProfilingConfig(
|
||||
name="qwenimage",
|
||||
pipeline_cls=QwenImagePipeline,
|
||||
pipeline_init_kwargs={
|
||||
"pretrained_model_name_or_path": "Qwen/Qwen-Image",
|
||||
"torch_dtype": torch.bfloat16,
|
||||
},
|
||||
pipeline_call_kwargs={
|
||||
"prompt": PROMPT,
|
||||
"negative_prompt": " ",
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
"num_inference_steps": 4,
|
||||
"true_cfg_scale": 4.0,
|
||||
"output_type": "latent",
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Profile diffusers pipelines with torch.profiler")
|
||||
parser.add_argument(
|
||||
"--pipeline",
|
||||
choices=["flux", "flux2", "wan", "ltx2", "qwenimage", "all"],
|
||||
required=True,
|
||||
help="Which pipeline to profile",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["eager", "compile", "both"],
|
||||
default="eager",
|
||||
help="Run in eager mode, compile mode, or both",
|
||||
)
|
||||
parser.add_argument("--output_dir", default="profiling_results", help="Directory for trace output")
|
||||
parser.add_argument("--num_steps", type=int, default=None, help="Override num_inference_steps")
|
||||
parser.add_argument("--full_decode", action="store_true", help="Profile including VAE decode (output_type='pil')")
|
||||
parser.add_argument(
|
||||
"--compile_mode",
|
||||
default="default",
|
||||
choices=["default", "reduce-overhead", "max-autotune"],
|
||||
help="torch.compile mode",
|
||||
)
|
||||
parser.add_argument("--compile_fullgraph", action="store_true", help="Use fullgraph=True for torch.compile")
|
||||
parser.add_argument(
|
||||
"--compile_regional",
|
||||
action="store_true",
|
||||
help="Use compile_repeated_blocks() instead of full model compile",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--benchmark",
|
||||
action="store_true",
|
||||
help="Benchmark wall-clock time instead of profiling. Uses CUDA events, no profiler overhead.",
|
||||
)
|
||||
parser.add_argument("--num_runs", type=int, default=5, help="Number of timed runs for benchmarking")
|
||||
parser.add_argument("--num_warmups", type=int, default=2, help="Number of warmup runs for benchmarking")
|
||||
args = parser.parse_args()
|
||||
|
||||
registry = build_registry()
|
||||
|
||||
pipeline_names = list(registry.keys()) if args.pipeline == "all" else [args.pipeline]
|
||||
modes = ["eager", "compile"] if args.mode == "both" else [args.mode]
|
||||
|
||||
for pipeline_name in pipeline_names:
|
||||
for mode in modes:
|
||||
config = copy.deepcopy(registry[pipeline_name])
|
||||
|
||||
# Apply overrides
|
||||
if args.num_steps is not None:
|
||||
config.pipeline_call_kwargs["num_inference_steps"] = args.num_steps
|
||||
if args.full_decode:
|
||||
config.pipeline_call_kwargs["output_type"] = "pil"
|
||||
if mode == "compile":
|
||||
config.compile_kwargs = {
|
||||
"fullgraph": args.compile_fullgraph,
|
||||
"mode": args.compile_mode,
|
||||
}
|
||||
config.compile_regional = args.compile_regional
|
||||
|
||||
profiler = PipelineProfiler(config, args.output_dir)
|
||||
try:
|
||||
if args.benchmark:
|
||||
logger.info(f"Benchmarking {pipeline_name} in {mode} mode...")
|
||||
profiler.benchmark(num_runs=args.num_runs, num_warmups=args.num_warmups)
|
||||
else:
|
||||
logger.info(f"Profiling {pipeline_name} in {mode} mode...")
|
||||
trace_file = profiler.run()
|
||||
logger.info(f"Done: {trace_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to {'benchmark' if args.benchmark else 'profile'} {pipeline_name} ({mode}): {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
215 examples/profiling/profiling_utils.py Normal file

@@ -0,0 +1,215 @@
import functools
import gc
import logging
import os
from dataclasses import dataclass, field
from typing import Any

import torch
import torch.profiler


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)


def annotate(func, name):
    """Wrap a function with torch.profiler.record_function for trace annotation."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with torch.profiler.record_function(name):
            return func(*args, **kwargs)

    return wrapper


def annotate_pipeline(pipe):
    """Apply profiler annotations to key pipeline methods.

    Monkey-patches bound methods so they appear as named spans in the trace.
    Non-invasive — no source modifications required.
    """
    annotations = [
        ("transformer", "forward", "transformer_forward"),
        ("vae", "decode", "vae_decode"),
        ("vae", "encode", "vae_encode"),
        ("scheduler", "step", "scheduler_step"),
    ]

    # Annotate sub-component methods
    for component_name, method_name, label in annotations:
        component = getattr(pipe, component_name, None)
        if component is None:
            continue
        method = getattr(component, method_name, None)
        if method is None:
            continue
        setattr(component, method_name, annotate(method, label))

    # Annotate pipeline-level methods
    if hasattr(pipe, "encode_prompt"):
        pipe.encode_prompt = annotate(pipe.encode_prompt, "encode_prompt")


def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


def benchmark_fn(f, *args, num_runs=5, num_warmups=2, **kwargs):
    """Benchmark a function using CUDA events for accurate GPU timing.

    Uses CUDA events to measure wall-clock time including GPU execution,
    without the overhead of torch.profiler. Reports mean and standard deviation
    over multiple runs.

    Returns:
        dict with keys: mean_ms, std_ms, runs_ms (list of individual timings)
    """
    # Warmup
    for _ in range(num_warmups):
        f(*args, **kwargs)
    torch.cuda.synchronize()

    # Timed runs
    times = []
    for _ in range(num_runs):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        f(*args, **kwargs)
        end.record()

        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    mean_ms = sum(times) / len(times)
    variance = sum((t - mean_ms) ** 2 for t in times) / len(times)
    std_ms = variance**0.5

    return {"mean_ms": mean_ms, "std_ms": std_ms, "runs_ms": times}


@dataclass
class PipelineProfilingConfig:
    name: str
    pipeline_cls: Any
    pipeline_init_kwargs: dict[str, Any]
    pipeline_call_kwargs: dict[str, Any]
    compile_kwargs: dict[str, Any] | None = field(default=None)
    compile_regional: bool = False


class PipelineProfiler:
    def __init__(self, config: PipelineProfilingConfig, output_dir: str = "profiling_results"):
        self.config = config
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def setup_pipeline(self, annotate=True):
        """Load the pipeline from pretrained, optionally compile, and annotate."""
        logger.info(f"Loading pipeline: {self.config.name}")
        pipe = self.config.pipeline_cls.from_pretrained(**self.config.pipeline_init_kwargs)
        pipe.to("cuda")

        if self.config.compile_kwargs:
            if self.config.compile_regional:
                logger.info(
                    f"Regional compilation (compile_repeated_blocks) with kwargs: {self.config.compile_kwargs}"
                )
                pipe.transformer.compile_repeated_blocks(**self.config.compile_kwargs)
            else:
                logger.info(f"Full compilation with kwargs: {self.config.compile_kwargs}")
                pipe.transformer.compile(**self.config.compile_kwargs)

        # Disable tqdm progress bar to avoid CPU overhead / IO between steps
        pipe.set_progress_bar_config(disable=True)

        if annotate:
            annotate_pipeline(pipe)
        return pipe

    def run(self):
        """Execute the profiling run: warmup, then profile one pipeline call."""
        pipe = self.setup_pipeline()
        flush()

        mode = "compile" if self.config.compile_kwargs else "eager"
        trace_file = os.path.join(self.output_dir, f"{self.config.name}_{mode}.json")

        # Warmup (pipeline __call__ is already decorated with @torch.no_grad())
        logger.info("Running warmup...")
        pipe(**self.config.pipeline_call_kwargs)
        flush()

        # Profile
        logger.info("Running profiled iteration...")
        activities = [
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ]
        with torch.profiler.profile(
            activities=activities,
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        ) as prof:
            with torch.profiler.record_function("pipeline_call"):
                pipe(**self.config.pipeline_call_kwargs)

        # Export trace
        prof.export_chrome_trace(trace_file)
        logger.info(f"Chrome trace saved to: {trace_file}")

        # Print summary
        print("\n" + "=" * 80)
        print(f"Profile summary: {self.config.name} ({mode})")
        print("=" * 80)
        print(
            prof.key_averages().table(
                sort_by="cuda_time_total",
                row_limit=20,
            )
        )

        # Cleanup
        pipe.to("cpu")
        del pipe
        flush()

        return trace_file

    def benchmark(self, num_runs=5, num_warmups=2):
        """Benchmark pipeline wall-clock time without profiler overhead.

        Uses CUDA events for accurate GPU-inclusive timing over multiple runs.
        No annotations are applied to avoid any overhead from record_function wrappers.
        Reports mean, std, and individual run times.
        """
        pipe = self.setup_pipeline(annotate=False)
        flush()

        mode = "compile" if self.config.compile_kwargs else "eager"

        logger.info(f"Benchmarking {self.config.name} ({mode}): {num_warmups} warmup + {num_runs} timed runs...")
        result = benchmark_fn(pipe, num_runs=num_runs, num_warmups=num_warmups, **self.config.pipeline_call_kwargs)

        print("\n" + "=" * 80)
        print(f"Benchmark: {self.config.name} ({mode})")
        print("=" * 80)
        print(f"  Runs: {num_runs} (after {num_warmups} warmup)")
        print(f"  Mean: {result['mean_ms']:.1f} ms")
        print(f"  Std: {result['std_ms']:.1f} ms")
        print(f"  Individual: {', '.join(f'{t:.1f}' for t in result['runs_ms'])} ms")
        print("=" * 80)

        # Cleanup
        pipe.to("cpu")
        del pipe
        flush()

        return result
46 examples/profiling/run_profiling.sh Executable file

@@ -0,0 +1,46 @@
#!/bin/bash
# Run profiling across all pipelines in eager and compile (regional) modes.
#
# Usage:
#   bash profiling/run_profiling.sh
#   bash profiling/run_profiling.sh --output_dir my_results

set -euo pipefail

OUTPUT_DIR="profiling_results"
while [[ $# -gt 0 ]]; do
  case "$1" in
    --output_dir) OUTPUT_DIR="$2"; shift 2 ;;
    *) echo "Unknown arg: $1"; exit 1 ;;
  esac
done
NUM_STEPS=2
# PIPELINES=("flux" "flux2" "wan" "ltx2" "qwenimage")
PIPELINES=("wan")
MODES=("eager" "compile")

for pipeline in "${PIPELINES[@]}"; do
  for mode in "${MODES[@]}"; do
    echo "============================================================"
    echo "Profiling: ${pipeline} | mode: ${mode}"
    echo "============================================================"

    COMPILE_ARGS=""
    if [ "$mode" = "compile" ]; then
      COMPILE_ARGS="--compile_regional --compile_fullgraph --compile_mode default"
    fi

    python profiling/profiling_pipelines.py \
      --pipeline "$pipeline" \
      --mode "$mode" \
      --output_dir "$OUTPUT_DIR" \
      --num_steps "$NUM_STEPS" \
      $COMPILE_ARGS

    echo ""
  done
done

echo "============================================================"
echo "All traces saved to: ${OUTPUT_DIR}/"
echo "============================================================"
@@ -169,22 +169,23 @@ else:

            "PyramidAttentionBroadcastConfig",
            "SmoothedEnergyGuidanceConfig",
            "TaylorSeerCacheConfig",
            "TextKVCacheConfig",
            "apply_faster_cache",
            "apply_first_block_cache",
            "apply_layer_skip",
            "apply_mag_cache",
            "apply_pyramid_attention_broadcast",
            "apply_taylorseer_cache",
            "apply_text_kv_cache",
        ]
    )
    _import_structure["image_processor"] = [
        "IPAdapterMaskProcessor",
        "InpaintProcessor",
        "IPAdapterMaskProcessor",
        "PixArtImageProcessor",
        "VaeImageProcessor",
        "VaeImageProcessorLDM3D",
    ]
    _import_structure["video_processor"] = ["VideoProcessor"]
    _import_structure["models"].extend(
        [
            "AllegroTransformer3DModel",

@@ -262,6 +263,7 @@ else:

            "MotionAdapter",
            "MultiAdapter",
            "MultiControlNetModel",
            "NucleusMoEImageTransformer2DModel",
            "OmniGenTransformer2DModel",
            "OvisImageTransformer2DModel",
            "ParallelConfig",

@@ -396,6 +398,7 @@ else:

        ]
    )
    _import_structure["training_utils"] = ["EMAModel"]
    _import_structure["video_processor"] = ["VideoProcessor"]

try:
    if not (is_torch_available() and is_scipy_available()):

@@ -613,6 +616,7 @@ else:

            "MarigoldNormalsPipeline",
            "MochiPipeline",
            "MusicLDMPipeline",
            "NucleusMoEImagePipeline",
            "OmniGenPipeline",
            "OvisImagePipeline",
            "PaintByExamplePipeline",

@@ -967,12 +971,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:

        PyramidAttentionBroadcastConfig,
        SmoothedEnergyGuidanceConfig,
        TaylorSeerCacheConfig,
        TextKVCacheConfig,
        apply_faster_cache,
        apply_first_block_cache,
        apply_layer_skip,
        apply_mag_cache,
        apply_pyramid_attention_broadcast,
        apply_taylorseer_cache,
        apply_text_kv_cache,
    )
    from .image_processor import (
        InpaintProcessor,

@@ -1057,6 +1063,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:

        MotionAdapter,
        MultiAdapter,
        MultiControlNetModel,
        NucleusMoEImageTransformer2DModel,
        OmniGenTransformer2DModel,
        OvisImageTransformer2DModel,
        ParallelConfig,

@@ -1384,6 +1391,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:

        MarigoldNormalsPipeline,
        MochiPipeline,
        MusicLDMPipeline,
        NucleusMoEImagePipeline,
        OmniGenPipeline,
        OvisImagePipeline,
        PaintByExamplePipeline,
@@ -27,3 +27,4 @@ if is_torch_available():

    from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
    from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
    from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
    from .text_kv_cache import TextKVCacheConfig, apply_text_kv_cache
@@ -271,12 +271,31 @@ class HookRegistry:
|
||||
if hook._is_stateful:
|
||||
hook._set_context(self._module_ref, name)
|
||||
|
||||
for registry in self._get_child_registries():
|
||||
registry._set_context(name)
|
||||
|
||||
def _get_child_registries(self) -> list["HookRegistry"]:
|
||||
"""Return registries of child modules, using a cached list when available.
|
||||
|
||||
The cache is built on first call and reused for subsequent calls. This avoids the cost of walking the full
|
||||
module tree via named_modules() on every _set_context call, which is significant for large models (e.g. ~2.7ms
|
||||
per call on Flux2).
|
||||
"""
|
||||
if not hasattr(self, "_child_registries_cache"):
|
||||
self._child_registries_cache = None
|
||||
|
||||
if self._child_registries_cache is not None:
|
||||
return self._child_registries_cache
|
||||
|
||||
registries = []
|
||||
for module_name, module in unwrap_module(self._module_ref).named_modules():
|
||||
if module_name == "":
|
||||
continue
|
||||
module = unwrap_module(module)
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
registries.append(module._diffusers_hook)
|
||||
self._child_registries_cache = registries
|
||||
return registries
|
||||
|
||||
def __repr__(self) -> str:
|
||||
registry_repr = ""
|
||||
|
||||
173
src/diffusers/hooks/text_kv_cache.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
|
||||
|
||||
|
||||
_TEXT_KV_CACHE_TRANSFORMER_HOOK = "text_kv_cache_transformer"
|
||||
_TEXT_KV_CACHE_BLOCK_HOOK = "text_kv_cache_block"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextKVCacheConfig:
|
||||
"""Enable exact (lossless) text K/V caching for transformer models.
|
||||
|
||||
Pre-computes per-block text key and value projections once before the denoising loop and reuses them across all
|
||||
steps. Positive and negative prompts are distinguished via a stable cache key captured by a transformer-level hook
|
||||
before any intermediate tensor allocations.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TextKVCacheState(BaseState):
|
||||
"""Shared state between the transformer-level and block-level hooks.
|
||||
|
||||
The transformer hook writes the stable ``encoder_hidden_states`` ``data_ptr()`` (captured *before* ``txt_norm``) so
|
||||
that block hooks can use it as a reliable cache key across denoising steps.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.key: int | None = None
|
||||
|
||||
def reset(self):
|
||||
self.key = None
|
||||
|
||||
|
||||
class TextKVCacheBlockState(BaseState):
|
||||
"""Per-block state holding cached text key/value projections."""
|
||||
|
||||
def __init__(self):
|
||||
self.kv_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = {}
|
||||
|
||||
def reset(self):
|
||||
self.kv_cache.clear()
|
||||
|
||||
|
||||
class TextKVCacheTransformerHook(ModelHook):
|
||||
"""Captures ``encoder_hidden_states.data_ptr()`` before ``txt_norm``
|
||||
and writes it to shared state for the block hooks to read."""
|
||||
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(self, state_manager: StateManager):
|
||||
super().__init__()
|
||||
self.state_manager = state_manager
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if self.state_manager._current_context is None:
|
||||
self.state_manager.set_context("inference")
|
||||
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states")
|
||||
if encoder_hidden_states is not None:
|
||||
state: TextKVCacheState = self.state_manager.get_state()
|
||||
state.key = encoder_hidden_states.data_ptr()
|
||||
return self.fn_ref.original_forward(*args, **kwargs)
|
||||
|
||||
def reset_state(self, module: torch.nn.Module):
|
||||
self.state_manager.reset()
|
||||
return module
|
||||
|
||||
|
||||
class TextKVCacheBlockHook(ModelHook):
|
||||
"""Caches ``(txt_key, txt_value)`` per block per unique prompt using
|
||||
the stable cache key from the shared state."""
|
||||
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(self, state_manager: StateManager, block_state_manager: StateManager):
|
||||
super().__init__()
|
||||
self.state_manager = state_manager
|
||||
self.block_state_manager = block_state_manager
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
from ..models.transformers.transformer_nucleusmoe_image import _apply_rotary_emb_nucleus
|
||||
|
||||
if self.state_manager._current_context is None:
|
||||
self.state_manager.set_context("inference")
|
||||
|
||||
if self.block_state_manager._current_context is None:
|
||||
self.block_state_manager.set_context("inference")
|
||||
|
||||
if "encoder_hidden_states" in kwargs:
|
||||
encoder_hidden_states = kwargs["encoder_hidden_states"]
|
||||
else:
|
||||
encoder_hidden_states = args[1]
|
||||
|
||||
if "image_rotary_emb" in kwargs:
|
||||
image_rotary_emb = kwargs["image_rotary_emb"]
|
||||
elif len(args) > 3:
|
||||
image_rotary_emb = args[3]
|
||||
else:
|
||||
image_rotary_emb = None
|
||||
|
||||
state: TextKVCacheState = self.state_manager.get_state()
|
||||
cache_key = state.key
|
||||
|
||||
block_state: TextKVCacheBlockState = self.block_state_manager.get_state()
|
||||
|
||||
if cache_key not in block_state.kv_cache:
|
||||
context = module.encoder_proj(encoder_hidden_states)
|
||||
|
||||
attn = module.attn
|
||||
head_dim = attn.inner_dim // attn.heads
|
||||
num_kv_heads = attn.inner_kv_dim // head_dim
|
||||
|
||||
txt_key = attn.add_k_proj(context).unflatten(-1, (num_kv_heads, -1))
|
||||
txt_value = attn.add_v_proj(context).unflatten(-1, (num_kv_heads, -1))
|
||||
|
||||
if attn.norm_added_k is not None:
|
||||
txt_key = attn.norm_added_k(txt_key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
_, txt_freqs = image_rotary_emb
|
||||
txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False)
|
||||
|
||||
block_state.kv_cache[cache_key] = (txt_key, txt_value)
|
||||
|
||||
txt_key, txt_value = block_state.kv_cache[cache_key]
|
||||
|
||||
attn_kwargs = kwargs.get("attention_kwargs") or {}
|
||||
attn_kwargs["cached_txt_key"] = txt_key
|
||||
attn_kwargs["cached_txt_value"] = txt_value
|
||||
kwargs["attention_kwargs"] = attn_kwargs
|
||||
|
||||
return self.fn_ref.original_forward(*args, **kwargs)
|
||||
|
||||
def reset_state(self, module: torch.nn.Module):
|
||||
self.block_state_manager.reset()
|
||||
return module
|
||||
|
||||
|
||||
def apply_text_kv_cache(module: torch.nn.Module, config: TextKVCacheConfig) -> None:
|
||||
from ..models.transformers.transformer_nucleusmoe_image import NucleusMoEImageTransformerBlock
|
||||
|
||||
HookRegistry.check_if_exists_or_initialize(module)
|
||||
|
||||
state_manager = StateManager(TextKVCacheState)
|
||||
|
||||
transformer_hook = TextKVCacheTransformerHook(state_manager)
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
registry.register_hook(transformer_hook, _TEXT_KV_CACHE_TRANSFORMER_HOOK)
|
||||
|
||||
for _, submodule in module.named_modules():
|
||||
if isinstance(submodule, NucleusMoEImageTransformerBlock):
|
||||
block_state_manager = StateManager(TextKVCacheBlockState)
|
||||
hook = TextKVCacheBlockHook(state_manager, block_state_manager)
|
||||
block_registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
||||
block_registry.register_hook(hook, _TEXT_KV_CACHE_BLOCK_HOOK)
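As a quick orientation for reviewers, here is a minimal usage sketch of the new hook, wired through `CacheMixin.enable_cache` as extended later in this diff. The checkpoint id, prompt, and call arguments are placeholders, not part of the change.

```py
import torch

from diffusers import NucleusMoEImagePipeline, TextKVCacheConfig

# Placeholder checkpoint id, for illustration only.
pipe = NucleusMoEImagePipeline.from_pretrained("org/nucleusmoe-image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Text K/V projections are computed once per unique prompt and reused across all denoising steps.
pipe.transformer.enable_cache(TextKVCacheConfig())
image = pipe("a red fox in the snow", num_inference_steps=28).images[0]

# Removes both the transformer-level and the per-block hooks.
pipe.transformer.disable_cache()
```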
|
||||
@@ -116,6 +116,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_nucleusmoe_image"] = ["NucleusMoEImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_ovis_image"] = ["OvisImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
|
||||
@@ -236,6 +237,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Lumina2Transformer2DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
MochiTransformer3DModel,
|
||||
NucleusMoEImageTransformer2DModel,
|
||||
OmniGenTransformer2DModel,
|
||||
OvisImageTransformer2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
|
||||
@@ -423,7 +423,9 @@ def dispatch_attention_fn(
|
||||
**attention_kwargs,
|
||||
"_parallel_config": parallel_config,
|
||||
}
|
||||
if is_torch_version(">=", "2.5.0"):
|
||||
# Equivalent to `is_torch_version(">=", "2.5.0")` — use module-level constant to avoid
|
||||
# Dynamo tracing into the lru_cache-wrapped `is_torch_version` during torch.compile.
|
||||
if _CAN_USE_FLEX_ATTN:
|
||||
kwargs["enable_gqa"] = enable_gqa
|
||||
|
||||
if _AttentionBackendRegistry._checks_enabled:
|
||||
|
||||
@@ -41,11 +41,12 @@ class CacheMixin:
|
||||
Enable caching techniques on the model.
|
||||
|
||||
Args:
|
||||
config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig`):
|
||||
config (`PyramidAttentionBroadcastConfig | FasterCacheConfig | FirstBlockCacheConfig | TextKVCacheConfig`):
|
||||
The configuration for applying the caching technique. Currently supported caching techniques are:
|
||||
- [`~hooks.PyramidAttentionBroadcastConfig`]
|
||||
- [`~hooks.FasterCacheConfig`]
|
||||
- [`~hooks.FirstBlockCacheConfig`]
|
||||
- [`~hooks.TextKVCacheConfig`]
|
||||
|
||||
Example:
|
||||
|
||||
@@ -71,11 +72,13 @@ class CacheMixin:
|
||||
MagCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
TaylorSeerCacheConfig,
|
||||
TextKVCacheConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_mag_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
apply_taylorseer_cache,
|
||||
apply_text_kv_cache,
|
||||
)
|
||||
|
||||
if self.is_cache_enabled:
|
||||
@@ -89,6 +92,8 @@ class CacheMixin:
|
||||
apply_first_block_cache(self, config)
|
||||
elif isinstance(config, MagCacheConfig):
|
||||
apply_mag_cache(self, config)
|
||||
elif isinstance(config, TextKVCacheConfig):
|
||||
apply_text_kv_cache(self, config)
|
||||
elif isinstance(config, PyramidAttentionBroadcastConfig):
|
||||
apply_pyramid_attention_broadcast(self, config)
|
||||
elif isinstance(config, TaylorSeerCacheConfig):
|
||||
@@ -106,12 +111,14 @@ class CacheMixin:
|
||||
MagCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
TaylorSeerCacheConfig,
|
||||
TextKVCacheConfig,
|
||||
)
|
||||
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
||||
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
|
||||
from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK
|
||||
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
|
||||
from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
|
||||
from ..hooks.text_kv_cache import _TEXT_KV_CACHE_BLOCK_HOOK, _TEXT_KV_CACHE_TRANSFORMER_HOOK
|
||||
|
||||
if self._cache_config is None:
|
||||
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
|
||||
@@ -129,6 +136,9 @@ class CacheMixin:
|
||||
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
|
||||
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, TextKVCacheConfig):
|
||||
registry.remove_hook(_TEXT_KV_CACHE_TRANSFORMER_HOOK, recurse=True)
|
||||
registry.remove_hook(_TEXT_KV_CACHE_BLOCK_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, TaylorSeerCacheConfig):
|
||||
registry.remove_hook(_TAYLORSEER_CACHE_HOOK, recurse=True)
|
||||
else:
|
||||
|
||||
@@ -40,6 +40,7 @@ if is_torch_available():
|
||||
from .transformer_ltx2 import LTX2VideoTransformer3DModel
|
||||
from .transformer_lumina2 import Lumina2Transformer2DModel
|
||||
from .transformer_mochi import MochiTransformer3DModel
|
||||
from .transformer_nucleusmoe_image import NucleusMoEImageTransformer2DModel
|
||||
from .transformer_omnigen import OmniGenTransformer2DModel
|
||||
from .transformer_ovis_image import OvisImageTransformer2DModel
|
||||
from .transformer_prx import PRXTransformer2DModel
|
||||
|
||||
@@ -0,0 +1,925 @@
|
||||
# Copyright 2025 Nucleus-Image Team, The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_qwenimage.apply_rotary_emb_qwen with qwen->nucleus
|
||||
def _apply_rotary_emb_nucleus(
|
||||
x: torch.Tensor,
|
||||
freqs_cis: torch.Tensor | tuple[torch.Tensor],
|
||||
use_real: bool = True,
|
||||
use_real_unbind_dim: int = -1,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
|
||||
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
|
||||
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
|
||||
tensors contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x (`torch.Tensor`):
Query or key tensor of shape `[B, S, H, D]` to which the rotary embeddings are applied.
freqs_cis (`tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)

Returns:
`torch.Tensor`: The query or key tensor with rotary embeddings applied.
"""
|
||||
if use_real:
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
|
||||
if use_real_unbind_dim == -1:
|
||||
# Used for flux, cogvideox, hunyuan-dit
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
elif use_real_unbind_dim == -2:
|
||||
# Used for Stable Audio, OmniGen, CogView4 and Cosmos
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
||||
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
||||
else:
|
||||
raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
|
||||
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
return out
|
||||
else:
|
||||
x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(1)
|
||||
x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
|
||||
|
||||
return x_out.type_as(x)
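A shape-level sketch of the complex (`use_real=False`) path that the attention processor below relies on; the tensors here are random placeholders.

```py
import torch

batch, seq, heads, head_dim = 2, 16, 4, 64
x = torch.randn(batch, seq, heads, head_dim)

# Complex frequencies of shape [S, D/2], as produced by NucleusMoEEmbedRope.
freqs_cis = torch.polar(torch.ones(seq, head_dim // 2), torch.randn(seq, head_dim // 2))

out = _apply_rotary_emb_nucleus(x, freqs_cis, use_real=False)
assert out.shape == x.shape and out.dtype == x.dtype
```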
|
||||
|
||||
|
||||
def _compute_text_seq_len_from_mask(
|
||||
encoder_hidden_states: torch.Tensor, encoder_hidden_states_mask: torch.Tensor | None
|
||||
) -> tuple[int, torch.Tensor | None, torch.Tensor | None]:
|
||||
batch_size, text_seq_len = encoder_hidden_states.shape[:2]
|
||||
if encoder_hidden_states_mask is None:
|
||||
return text_seq_len, None, None
|
||||
|
||||
if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len):
|
||||
raise ValueError(
|
||||
f"`encoder_hidden_states_mask` shape {encoder_hidden_states_mask.shape} must match "
|
||||
f"(batch_size, text_seq_len)=({batch_size}, {text_seq_len})."
|
||||
)
|
||||
|
||||
if encoder_hidden_states_mask.dtype != torch.bool:
|
||||
encoder_hidden_states_mask = encoder_hidden_states_mask.to(torch.bool)
|
||||
|
||||
position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
|
||||
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
|
||||
has_active = encoder_hidden_states_mask.any(dim=1)
|
||||
per_sample_len = torch.where(
|
||||
has_active,
|
||||
active_positions.max(dim=1).values + 1,
|
||||
torch.as_tensor(text_seq_len, device=encoder_hidden_states.device),
|
||||
)
|
||||
return text_seq_len, per_sample_len, encoder_hidden_states_mask
|
||||
|
||||
|
||||
class NucleusMoETimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, use_additional_t_cond=False):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(
|
||||
num_channels=embedding_dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000
|
||||
)
|
||||
self.timestep_embedder = TimestepEmbedding(
|
||||
in_channels=embedding_dim, time_embed_dim=4 * embedding_dim, out_dim=embedding_dim
|
||||
)
|
||||
self.norm = RMSNorm(embedding_dim, eps=1e-6)
|
||||
self.use_additional_t_cond = use_additional_t_cond
|
||||
if use_additional_t_cond:
|
||||
self.addition_t_embedding = nn.Embedding(2, embedding_dim)
|
||||
|
||||
def forward(self, timestep, hidden_states, addition_t_cond=None):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
|
||||
|
||||
conditioning = timesteps_emb
|
||||
if self.use_additional_t_cond:
|
||||
if addition_t_cond is None:
|
||||
raise ValueError("When additional_t_cond is True, addition_t_cond must be provided.")
|
||||
addition_t_emb = self.addition_t_embedding(addition_t_cond)
|
||||
addition_t_emb = addition_t_emb.to(dtype=hidden_states.dtype)
|
||||
conditioning = conditioning + addition_t_emb
|
||||
|
||||
return self.norm(conditioning)
|
||||
|
||||
|
||||
class NucleusMoEEmbedRope(nn.Module):
|
||||
def __init__(self, theta: int, axes_dim: list[int], scale_rope=False):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
pos_index = torch.arange(4096)
|
||||
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
||||
self.pos_freqs = torch.cat(
|
||||
[
|
||||
self._rope_params(pos_index, self.axes_dim[0], self.theta),
|
||||
self._rope_params(pos_index, self.axes_dim[1], self.theta),
|
||||
self._rope_params(pos_index, self.axes_dim[2], self.theta),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
self.neg_freqs = torch.cat(
|
||||
[
|
||||
self._rope_params(neg_index, self.axes_dim[0], self.theta),
|
||||
self._rope_params(neg_index, self.axes_dim[1], self.theta),
|
||||
self._rope_params(neg_index, self.axes_dim[2], self.theta),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
self.scale_rope = scale_rope
|
||||
|
||||
@staticmethod
|
||||
def _rope_params(index, dim, theta=10000):
|
||||
assert dim % 2 == 0
|
||||
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
|
||||
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video_fhw: tuple[int, int, int] | list[tuple[int, int, int]],
|
||||
device: torch.device = None,
|
||||
max_txt_seq_len: int | torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
video_fhw (`tuple[int, int, int]` or `list[tuple[int, int, int]]`):
|
||||
A list of 3 integers [frame, height, width] representing the shape of the video.
|
||||
device: (`torch.device`, *optional*):
|
||||
The device on which to perform the RoPE computation.
|
||||
max_txt_seq_len (`int` or `torch.Tensor`, *optional*):
|
||||
The maximum text sequence length for RoPE computation.
|
||||
"""
|
||||
if max_txt_seq_len is None:
|
||||
raise ValueError("Either `max_txt_seq_len` must be provided.")
|
||||
|
||||
if isinstance(video_fhw, list) and len(video_fhw) > 1:
|
||||
first_fhw = video_fhw[0]
|
||||
if not all(fhw == first_fhw for fhw in video_fhw):
|
||||
logger.warning(
|
||||
"Batch inference with variable-sized images is not currently supported in NucleusMoEEmbedRope. "
|
||||
"All images in the batch should have the same dimensions (frame, height, width). "
|
||||
f"Detected sizes: {video_fhw}. Using the first image's dimensions {first_fhw} "
|
||||
"for RoPE computation, which may lead to incorrect results for other images in the batch."
|
||||
)
|
||||
|
||||
if isinstance(video_fhw, list):
|
||||
video_fhw = video_fhw[0]
|
||||
if not isinstance(video_fhw, list):
|
||||
video_fhw = [video_fhw]
|
||||
|
||||
vid_freqs = []
|
||||
for idx, fhw in enumerate(video_fhw):
|
||||
frame, height, width = fhw
|
||||
video_freq = self._compute_video_freqs(frame, height, width, idx, device)
|
||||
vid_freqs.append(video_freq)
|
||||
|
||||
max_txt_seq_len_int = int(max_txt_seq_len)
|
||||
if self.scale_rope:
|
||||
max_vid_index = torch.maximum(
|
||||
torch.tensor(height // 2, device=device, dtype=torch.long),
|
||||
torch.tensor(width // 2, device=device, dtype=torch.long),
|
||||
)
|
||||
else:
|
||||
max_vid_index = torch.maximum(
|
||||
torch.tensor(height, device=device, dtype=torch.long),
|
||||
torch.tensor(width, device=device, dtype=torch.long),
|
||||
)
|
||||
|
||||
txt_freqs = self.pos_freqs.to(device)[max_vid_index + torch.arange(max_txt_seq_len_int, device=device)]
|
||||
vid_freqs = torch.cat(vid_freqs, dim=0)
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _compute_video_freqs(
|
||||
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
|
||||
) -> torch.Tensor:
|
||||
seq_lens = frame * height * width
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
|
||||
|
||||
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
||||
|
||||
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
||||
if self.scale_rope:
|
||||
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
||||
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
||||
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
else:
|
||||
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
||||
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
||||
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
|
||||
class NucleusMoEAttnProcessor2_0:
|
||||
"""
|
||||
Attention processor for the NucleusMoE architecture. Image queries attend to concatenated image+text keys/values
|
||||
(cross-attention style, no text query). Supports grouped-query attention (GQA) when num_key_value_heads is set on
|
||||
the Attention module.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"NucleusMoEAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: torch.FloatTensor | None = None,
|
||||
image_rotary_emb: torch.Tensor | None = None,
|
||||
cached_txt_key: torch.FloatTensor | None = None,
|
||||
cached_txt_value: torch.FloatTensor | None = None,
|
||||
) -> torch.FloatTensor:
|
||||
head_dim = attn.inner_dim // attn.heads
|
||||
num_kv_heads = attn.inner_kv_dim // head_dim
|
||||
num_kv_groups = attn.heads // num_kv_heads
|
||||
|
||||
img_query = attn.to_q(hidden_states).unflatten(-1, (attn.heads, -1))
|
||||
img_key = attn.to_k(hidden_states).unflatten(-1, (num_kv_heads, -1))
|
||||
img_value = attn.to_v(hidden_states).unflatten(-1, (num_kv_heads, -1))
|
||||
|
||||
if attn.norm_q is not None:
|
||||
img_query = attn.norm_q(img_query)
|
||||
if attn.norm_k is not None:
|
||||
img_key = attn.norm_k(img_key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
img_freqs, txt_freqs = image_rotary_emb
|
||||
img_query = _apply_rotary_emb_nucleus(img_query, img_freqs, use_real=False)
|
||||
img_key = _apply_rotary_emb_nucleus(img_key, img_freqs, use_real=False)
|
||||
|
||||
if cached_txt_key is not None and cached_txt_value is not None:
|
||||
txt_key, txt_value = cached_txt_key, cached_txt_value
|
||||
joint_key = torch.cat([img_key, txt_key], dim=1)
|
||||
joint_value = torch.cat([img_value, txt_value], dim=1)
|
||||
elif encoder_hidden_states is not None:
|
||||
txt_key = attn.add_k_proj(encoder_hidden_states).unflatten(-1, (num_kv_heads, -1))
|
||||
txt_value = attn.add_v_proj(encoder_hidden_states).unflatten(-1, (num_kv_heads, -1))
|
||||
|
||||
if attn.norm_added_k is not None:
|
||||
txt_key = attn.norm_added_k(txt_key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False)
|
||||
|
||||
joint_key = torch.cat([img_key, txt_key], dim=1)
|
||||
joint_value = torch.cat([img_value, txt_value], dim=1)
|
||||
else:
|
||||
joint_key = img_key
|
||||
joint_value = img_value
|
||||
|
||||
if num_kv_groups > 1:
|
||||
joint_key = joint_key.repeat_interleave(num_kv_groups, dim=2)
|
||||
joint_value = joint_value.repeat_interleave(num_kv_groups, dim=2)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
img_query,
|
||||
joint_key,
|
||||
joint_value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(img_query.dtype)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
if len(attn.to_out) > 1:
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def _is_moe_layer(strategy: str, layer_idx: int, num_layers: int) -> bool:
|
||||
if strategy == "leave_first_three_and_last_block_dense":
|
||||
return layer_idx >= 3 and layer_idx < num_layers - 1
|
||||
elif strategy == "leave_first_three_blocks_dense":
|
||||
return layer_idx >= 3
|
||||
elif strategy == "leave_first_block_dense":
|
||||
return layer_idx >= 1
|
||||
elif strategy == "all_moe":
|
||||
return True
|
||||
elif strategy == "all_dense":
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class SwiGLUExperts(nn.Module):
|
||||
"""
|
||||
Packed SwiGLU feed-forward experts for MoE: ``gate, up = (x @ gate_up_proj).chunk(2); out = (silu(gate) * up) @
|
||||
down_proj``.
|
||||
|
||||
Gate and up projections are fused into a single weight ``gate_up_proj`` so that only two grouped matmuls are needed
|
||||
at runtime (gate+up combined, then down).
|
||||
|
||||
Weights are stored pre-transposed relative to the standard linear-layer convention so that matmuls can be issued
|
||||
without a transpose at runtime.
|
||||
|
||||
Weight shapes:
|
||||
gate_up_proj: (num_experts, hidden_size, 2 * moe_intermediate_dim) -- fused gate + up projection
down_proj: (num_experts, moe_intermediate_dim, hidden_size) -- down projection
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
moe_intermediate_dim: int,
|
||||
num_experts: int,
|
||||
use_grouped_mm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.moe_intermediate_dim = moe_intermediate_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.use_grouped_mm = use_grouped_mm
|
||||
|
||||
self.gate_up_proj = nn.Parameter(torch.empty(num_experts, hidden_size, 2 * moe_intermediate_dim))
|
||||
self.down_proj = nn.Parameter(torch.empty(num_experts, moe_intermediate_dim, hidden_size))
|
||||
|
||||
def _run_experts_for_loop(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
num_tokens_per_expert: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute SwiGLU MoE expert outputs using a sequential per-expert for loop.
|
||||
|
||||
Tokens in ``x`` must be pre-sorted so that all tokens assigned to expert 0 come first, followed by expert 1,
|
||||
and so on — i.e. the layout produced by a standard token-permutation step (e.g. ``generate_permute_indices``).
|
||||
|
||||
``x`` may contain trailing padding rows appended by the permutation utility to reach a length that is a
|
||||
multiple of some alignment requirement. The padding rows are stripped before expert computation and re-appended
|
||||
as zeros so that the output shape matches ``x.shape``, keeping downstream scatter/gather indices valid.
|
||||
|
||||
.. note::
|
||||
``num_tokens_per_expert.tolist()`` synchronises the device with the host. This is acceptable for the loop
|
||||
path but means the method introduces a pipeline bubble. Use :meth:`forward` with ``use_grouped_mm=True``
|
||||
when a fully device-resident kernel is required (e.g. inside ``torch.compile``).
|
||||
|
||||
SwiGLU formula::
|
||||
|
||||
gate, up = (x @ gate_up_proj).chunk(2)
out = (silu(gate) * up) @ down_proj
|
||||
|
||||
Args:
|
||||
x (Tensor): Pre-permuted input tokens of shape
|
||||
``(total_tokens_including_padding, hidden_dim)``.
|
||||
num_tokens_per_expert (Tensor): 1-D integer tensor of length
|
||||
``num_experts`` giving the number of real (non-padding) tokens assigned to each expert. Values may
|
||||
differ across experts to support load-imbalanced routing.
|
||||
|
||||
Returns:
|
||||
Tensor of shape ``(total_tokens_including_padding, hidden_dim)``. Positions corresponding to padding rows
|
||||
contain zeros.
|
||||
"""
|
||||
# .tolist() triggers a host-device sync; see docstring note above.
|
||||
num_tokens_per_expert_list = num_tokens_per_expert.tolist()
|
||||
|
||||
# x may be padded to a larger buffer size by the permutation utility.
|
||||
# Track the padding count so we can restore the original buffer shape.
|
||||
num_real_tokens = sum(num_tokens_per_expert_list)
|
||||
num_padding = x.shape[0] - num_real_tokens
|
||||
|
||||
# Split the real-token prefix of x into per-expert slices (variable length).
|
||||
x_per_expert = torch.split(
|
||||
x[:num_real_tokens],
|
||||
split_size_or_sections=num_tokens_per_expert_list,
|
||||
dim=0,
|
||||
)
|
||||
|
||||
expert_outputs = []
|
||||
for expert_idx, x_expert in enumerate(x_per_expert):
|
||||
gate_up = torch.matmul(x_expert, self.gate_up_proj[expert_idx])
|
||||
gate, up = gate_up.chunk(2, dim=-1)
|
||||
out_expert = torch.matmul(F.silu(gate) * up, self.down_proj[expert_idx])
|
||||
expert_outputs.append(out_expert)
|
||||
|
||||
# Concatenate real-token outputs, then re-append zero rows for the padding.
|
||||
out = torch.cat(expert_outputs, dim=0)
|
||||
out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
|
||||
return out
|
||||
|
||||
def _run_experts_grouped_mm(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
num_tokens_per_expert: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute SwiGLU MoE expert outputs using fused grouped GEMM kernels.
|
||||
|
||||
Tokens in ``x`` must be pre-sorted so that all tokens assigned to expert 0 come first, followed by expert 1,
|
||||
and so on — the same layout required by :meth:`_run_experts_for_loop`.
|
||||
|
||||
This method is fully device-resident (no host-device sync) and is compatible with ``torch.compile``.
|
||||
|
||||
``F.grouped_mm`` is called with *exclusive end* offsets: ``offsets[k]`` is the exclusive end index of expert
|
||||
``k``'s token range in ``x`` (equivalently the inclusive start of expert ``k+1``'s range). This is the
|
||||
cumulative sum of ``num_tokens_per_expert``.
|
||||
|
||||
SwiGLU formula::
|
||||
|
||||
gate, up = (x @ gate_up_proj).chunk(2)
out = (silu(gate) * up) @ down_proj
|
||||
|
||||
Args:
|
||||
x (Tensor): Pre-permuted input tokens of shape
|
||||
``(total_tokens, hidden_dim)``. No padding rows expected; ``total_tokens`` must equal
|
||||
``num_tokens_per_expert.sum()``.
|
||||
num_tokens_per_expert (Tensor): 1-D integer tensor of length
|
||||
``num_experts`` giving the number of tokens assigned to each expert.
|
||||
|
||||
Returns:
|
||||
Tensor of shape ``(total_tokens, hidden_dim)`` with dtype matching ``x``.
|
||||
"""
|
||||
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
|
||||
|
||||
gate_up = F.grouped_mm(x, self.gate_up_proj, offs=offsets)
|
||||
gate, up = gate_up.chunk(2, dim=-1)
|
||||
out = F.grouped_mm(F.silu(gate) * up, self.down_proj, offs=offsets)
|
||||
|
||||
return out.type_as(x)
|
||||
|
||||
def forward(self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor) -> torch.Tensor:
|
||||
if self.use_grouped_mm:
|
||||
return self._run_experts_grouped_mm(x, num_tokens_per_expert)
|
||||
return self._run_experts_for_loop(x, num_tokens_per_expert)
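A toy sketch of the for-loop path, assuming made-up sizes; tokens are already permuted so that each expert's slice is contiguous, as the docstrings above require.

```py
import torch

experts = SwiGLUExperts(hidden_size=64, moe_intermediate_dim=128, num_experts=4, use_grouped_mm=False)
# The weights are allocated with torch.empty, so initialize them for the sketch.
torch.nn.init.normal_(experts.gate_up_proj, std=0.02)
torch.nn.init.normal_(experts.down_proj, std=0.02)

# Number of real tokens routed to experts 0..3, in routing order.
num_tokens_per_expert = torch.tensor([3, 5, 2, 6])
x = torch.randn(int(num_tokens_per_expert.sum()), 64)

out = experts(x, num_tokens_per_expert)
assert out.shape == x.shape
```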
|
||||
|
||||
|
||||
class NucleusMoELayer(nn.Module):
|
||||
"""
|
||||
Mixture-of-Experts layer with expert-choice routing and a shared expert.
|
||||
|
||||
Routed expert weights live in :class:`SwiGLUExperts`. The router concatenates a timestep embedding with the
|
||||
(unmodulated) hidden state to produce per-token affinity scores, then selects the top-C tokens per expert
|
||||
(expert-choice routing). A shared expert processes all tokens in parallel and its output is combined with the
|
||||
routed expert outputs via scatter-add.
|
||||
|
||||
SwiGLU expert computation is implemented by :class:`SwiGLUExperts`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
moe_intermediate_dim: int,
|
||||
num_experts: int,
|
||||
capacity_factor: float,
|
||||
use_sigmoid: bool,
|
||||
route_scale: float,
|
||||
use_grouped_mm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.moe_intermediate_dim = moe_intermediate_dim
|
||||
self.hidden_size = hidden_size
|
||||
self.capacity_factor = capacity_factor
|
||||
self.use_sigmoid = use_sigmoid
|
||||
self.route_scale = route_scale
|
||||
|
||||
self.gate = nn.Linear(hidden_size * 2, num_experts, bias=False)
|
||||
|
||||
self.experts = SwiGLUExperts(
|
||||
hidden_size=hidden_size,
|
||||
moe_intermediate_dim=moe_intermediate_dim,
|
||||
num_experts=num_experts,
|
||||
use_grouped_mm=use_grouped_mm,
|
||||
)
|
||||
|
||||
self.shared_expert = FeedForward(
|
||||
dim=hidden_size,
|
||||
dim_out=hidden_size,
|
||||
inner_dim=moe_intermediate_dim,
|
||||
activation_fn="swiglu",
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
hidden_states_unmodulated: torch.Tensor,
|
||||
timestep: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
bs, slen, dim = hidden_states.shape
|
||||
|
||||
if timestep is not None:
|
||||
timestep_expanded = timestep.unsqueeze(1).expand(-1, slen, -1)
|
||||
router_input = torch.cat([timestep_expanded, hidden_states_unmodulated], dim=-1)
|
||||
else:
|
||||
router_input = hidden_states_unmodulated
|
||||
|
||||
logits = self.gate(router_input)
|
||||
|
||||
if self.use_sigmoid:
|
||||
scores = torch.sigmoid(logits.float()).to(logits.dtype)
|
||||
else:
|
||||
scores = F.softmax(logits.float(), dim=-1).to(logits.dtype)
|
||||
|
||||
affinity = scores.transpose(1, 2) # (B, E, S)
|
||||
capacity = max(1, math.ceil(self.capacity_factor * slen / self.num_experts))
|
||||
|
||||
topk = torch.topk(affinity, k=capacity, dim=-1)
|
||||
top_indices = topk.indices # (B, E, C)
|
||||
gating = affinity.gather(dim=-1, index=top_indices) # (B, E, C)
|
||||
|
||||
batch_offsets = torch.arange(bs, device=hidden_states.device, dtype=torch.long).view(bs, 1, 1) * slen
|
||||
global_token_indices = (batch_offsets + top_indices).transpose(0, 1).reshape(self.num_experts, -1).reshape(-1)
|
||||
gating_flat = gating.transpose(0, 1).reshape(self.num_experts, -1).reshape(-1)
|
||||
|
||||
token_score_sums = torch.zeros(bs * slen, device=hidden_states.device, dtype=gating_flat.dtype)
|
||||
token_score_sums.scatter_add_(0, global_token_indices, gating_flat)
|
||||
gating_flat = gating_flat / (token_score_sums[global_token_indices] + 1e-12)
|
||||
gating_flat = gating_flat * self.route_scale
|
||||
|
||||
x_flat = hidden_states.reshape(bs * slen, dim)
|
||||
routed_input = x_flat[global_token_indices]
|
||||
|
||||
tokens_per_expert = bs * capacity
|
||||
num_tokens_per_expert = torch.full(
|
||||
(self.num_experts,),
|
||||
tokens_per_expert,
|
||||
device=hidden_states.device,
|
||||
dtype=torch.long,
|
||||
)
|
||||
routed_output = self.experts(routed_input, num_tokens_per_expert)
|
||||
routed_output = (routed_output.float() * gating_flat.unsqueeze(-1)).to(hidden_states.dtype)
|
||||
|
||||
out = self.shared_expert(hidden_states).reshape(bs * slen, dim)
|
||||
|
||||
scatter_idx = global_token_indices.reshape(-1, 1).expand(-1, dim)
|
||||
out = out.scatter_add(dim=0, index=scatter_idx, src=routed_output)
|
||||
out = out.reshape(bs, slen, dim)
|
||||
|
||||
return out
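To make the expert-choice arithmetic above concrete (values are illustrative, matching the defaults): each expert selects its own top-`capacity` tokens, so the per-expert workload stays fixed regardless of how popular any individual token is, while the shared expert still processes every token.

```py
import math

capacity_factor, num_experts, seq_len = 8.0, 128, 4096  # illustrative values
capacity = max(1, math.ceil(capacity_factor * seq_len / num_experts))
print(capacity)  # 256 tokens chosen per expert for each batch element
```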
|
||||
|
||||
|
||||
class NucleusMoEImageTransformerBlock(nn.Module):
|
||||
"""
|
||||
Single-stream DiT block with optional Mixture-of-Experts MLP. Only the image stream receives adaptive modulation;
|
||||
the text context is projected per-block and used as cross-attention keys/values.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
num_key_value_heads: int | None = None,
|
||||
joint_attention_dim: int = 3584,
|
||||
qk_norm: str = "rms_norm",
|
||||
eps: float = 1e-6,
|
||||
mlp_ratio: float = 4.0,
|
||||
moe_enabled: bool = False,
|
||||
num_experts: int = 128,
|
||||
moe_intermediate_dim: int = 1344,
|
||||
capacity_factor: float = 8.0,
|
||||
use_sigmoid: bool = False,
|
||||
route_scale: float = 2.5,
|
||||
use_grouped_mm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.moe_enabled = moe_enabled
|
||||
|
||||
self.img_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, 4 * dim, bias=True),
|
||||
)
|
||||
|
||||
self.encoder_proj = nn.Linear(joint_attention_dim, dim)
|
||||
|
||||
self.pre_attn_norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_key_value_heads,
|
||||
dim_head=attention_head_dim,
|
||||
added_kv_proj_dim=dim,
|
||||
added_proj_bias=False,
|
||||
out_dim=dim,
|
||||
out_bias=False,
|
||||
bias=False,
|
||||
processor=NucleusMoEAttnProcessor2_0(),
|
||||
qk_norm=qk_norm,
|
||||
eps=eps,
|
||||
context_pre_only=None,
|
||||
)
|
||||
|
||||
self.pre_mlp_norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
|
||||
|
||||
if moe_enabled:
|
||||
self.img_mlp = NucleusMoELayer(
|
||||
hidden_size=dim,
|
||||
moe_intermediate_dim=moe_intermediate_dim,
|
||||
num_experts=num_experts,
|
||||
capacity_factor=capacity_factor,
|
||||
use_sigmoid=use_sigmoid,
|
||||
route_scale=route_scale,
|
||||
use_grouped_mm=use_grouped_mm,
|
||||
)
|
||||
else:
|
||||
mlp_inner_dim = int(dim * mlp_ratio * 2 / 3) // 128 * 128
|
||||
self.img_mlp = FeedForward(
|
||||
dim=dim,
|
||||
dim_out=dim,
|
||||
inner_dim=mlp_inner_dim,
|
||||
activation_fn="swiglu",
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
) -> torch.Tensor:
|
||||
scale1, gate1, scale2, gate2 = self.img_mod(temb).unsqueeze(1).chunk(4, dim=-1)
|
||||
|
||||
gate1 = gate1.clamp(min=-2.0, max=2.0)
|
||||
gate2 = gate2.clamp(min=-2.0, max=2.0)
|
||||
|
||||
attn_kwargs = attention_kwargs or {}
|
||||
context = None if attn_kwargs.get("cached_txt_key") is not None else self.encoder_proj(encoder_hidden_states)
|
||||
|
||||
img_normed = self.pre_attn_norm(hidden_states)
|
||||
img_modulated = img_normed * (1 + scale1)
|
||||
|
||||
img_attn_output = self.attn(
|
||||
hidden_states=img_modulated,
|
||||
encoder_hidden_states=context,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**attn_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + gate1.tanh() * img_attn_output
|
||||
|
||||
img_normed2 = self.pre_mlp_norm(hidden_states)
|
||||
img_modulated2 = img_normed2 * (1 + scale2)
|
||||
|
||||
if self.moe_enabled:
|
||||
img_mlp_output = self.img_mlp(img_modulated2, img_normed2, timestep=temb)
|
||||
else:
|
||||
img_mlp_output = self.img_mlp(img_modulated2)
|
||||
|
||||
hidden_states = hidden_states + gate2.tanh() * img_mlp_output
|
||||
|
||||
if hidden_states.dtype == torch.float16:
|
||||
fp16_finfo = torch.finfo(torch.float16)
|
||||
hidden_states = hidden_states.clip(fp16_finfo.min, fp16_finfo.max)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class NucleusMoEImageTransformer2DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
|
||||
):
|
||||
"""
|
||||
Nucleus MoE Transformer for image generation. Single-stream DiT with cross-attention to text and optional
|
||||
Mixture-of-Experts feed-forward layers.
|
||||
|
||||
Args:
|
||||
patch_size (`int`, defaults to `2`):
|
||||
Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, defaults to `64`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, *optional*, defaults to `None`):
|
||||
The number of channels in the output. If not specified, it defaults to `in_channels`.
|
||||
num_layers (`int`, defaults to `24`):
|
||||
The number of transformer blocks.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of dimensions to use for each attention head.
|
||||
num_attention_heads (`int`, defaults to `16`):
|
||||
The number of attention heads to use.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
The number of key/value heads for grouped-query attention. Defaults to `num_attention_heads`.
|
||||
joint_attention_dim (`int`, defaults to `3584`):
|
||||
The embedding dimension of the encoder hidden states (text).
|
||||
axes_dims_rope (`tuple[int]`, defaults to `(16, 56, 56)`):
|
||||
The dimensions to use for the rotary positional embeddings.
|
||||
mlp_ratio (`float`, defaults to `4.0`):
|
||||
Multiplier for the MLP hidden dimension in dense (non-MoE) blocks.
|
||||
moe_enabled (`bool`, defaults to `True`):
|
||||
Whether to use Mixture-of-Experts layers.
|
||||
dense_moe_strategy (`str`, defaults to ``"leave_first_three_and_last_block_dense"``):
|
||||
Strategy for choosing which layers are MoE vs dense.
|
||||
num_experts (`int`, defaults to `128`):
|
||||
Number of experts per MoE layer.
|
||||
moe_intermediate_dim (`int`, defaults to `1344`):
|
||||
Hidden dimension inside each expert.
|
||||
capacity_factors (`float | list[float]`, defaults to `8.0`):
|
||||
Expert-choice capacity factor per layer.
|
||||
use_sigmoid (`bool`, defaults to `False`):
|
||||
Use sigmoid instead of softmax for routing scores.
|
||||
route_scale (`float`, defaults to `2.5`):
|
||||
Scaling factor applied to routing weights.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["NucleusMoEImageTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_repeated_blocks = ["NucleusMoEImageTransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 64,
|
||||
out_channels: int | None = None,
|
||||
num_layers: int = 24,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads: int | None = None,
|
||||
joint_attention_dim: int = 3584,
|
||||
axes_dims_rope: tuple[int, int, int] = (16, 56, 56),
|
||||
mlp_ratio: float = 4.0,
|
||||
moe_enabled: bool = True,
|
||||
dense_moe_strategy: str = "leave_first_three_and_last_block_dense",
|
||||
num_experts: int = 128,
|
||||
moe_intermediate_dim: int = 1344,
|
||||
capacity_factors: float | list[float] = 8.0,
|
||||
use_sigmoid: bool = False,
|
||||
route_scale: float = 2.5,
|
||||
use_grouped_mm: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
capacity_factors = capacity_factors if isinstance(capacity_factors, list) else [capacity_factors] * num_layers
|
||||
|
||||
self.pos_embed = NucleusMoEEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
|
||||
|
||||
self.time_text_embed = NucleusMoETimestepProjEmbeddings(embedding_dim=self.inner_dim)
|
||||
|
||||
self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
|
||||
self.img_in = nn.Linear(in_channels, self.inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
NucleusMoEImageTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
joint_attention_dim=joint_attention_dim,
|
||||
mlp_ratio=mlp_ratio,
|
||||
moe_enabled=moe_enabled and _is_moe_layer(dense_moe_strategy, idx, num_layers),
|
||||
num_experts=num_experts,
|
||||
moe_intermediate_dim=moe_intermediate_dim,
|
||||
capacity_factor=capacity_factors[idx],
|
||||
use_sigmoid=use_sigmoid,
|
||||
route_scale=route_scale,
|
||||
use_grouped_mm=use_grouped_mm,
|
||||
)
|
||||
for idx in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
img_shapes: tuple[int, int, int] | list[tuple[int, int, int]],
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
encoder_hidden_states_mask: torch.Tensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor | Transformer2DModelOutput:
|
||||
"""
|
||||
The [`NucleusMoEImageTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
||||
Input `hidden_states`.
|
||||
img_shapes (`list[tuple[int, int, int]]`, *optional*):
|
||||
Image shapes ``(frame, height, width)`` for RoPE computation.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
|
||||
Boolean mask for the encoder hidden states.
|
||||
timestep (`torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Extra kwargs forwarded to the attention processor.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.transformer_2d.Transformer2DModelOutput`].
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if attention_kwargs is not None:
|
||||
attention_kwargs = attention_kwargs.copy()
|
||||
lora_scale = attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
scale_lora_layers(self, lora_scale)
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
|
||||
text_seq_len, _, encoder_hidden_states_mask = _compute_text_seq_len_from_mask(
|
||||
encoder_hidden_states, encoder_hidden_states_mask
|
||||
)
|
||||
|
||||
temb = self.time_text_embed(timestep, hidden_states)
|
||||
|
||||
image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
|
||||
|
||||
block_attention_kwargs = attention_kwargs.copy() if attention_kwargs is not None else {}
|
||||
if encoder_hidden_states_mask is not None:
|
||||
batch_size, image_seq_len = hidden_states.shape[:2]
|
||||
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
|
||||
joint_attention_mask = torch.cat([image_mask, encoder_hidden_states_mask], dim=1)
|
||||
block_attention_kwargs["attention_mask"] = joint_attention_mask
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
block_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_kwargs=block_attention_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
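A smoke-test sketch of the forward interface with deliberately tiny, made-up dimensions; the weights are untrained, so this only exercises shapes.

```py
import torch

model = NucleusMoEImageTransformer2DModel(
    patch_size=2,
    in_channels=16,
    num_layers=2,
    attention_head_dim=32,
    num_attention_heads=4,
    num_key_value_heads=2,
    joint_attention_dim=32,
    axes_dims_rope=(8, 12, 12),  # must sum to attention_head_dim
    moe_enabled=True,
    dense_moe_strategy="leave_first_block_dense",
    num_experts=4,
    moe_intermediate_dim=64,
    capacity_factors=2.0,  # keep the per-expert capacity below the toy sequence length
)

batch, height, width, text_len = 1, 8, 8, 7
out = model(
    hidden_states=torch.randn(batch, height * width, 16),
    img_shapes=[(1, height, width)],
    encoder_hidden_states=torch.randn(batch, text_len, 32),
    timestep=torch.tensor([500]),
)
print(out.sample.shape)  # (1, 64, patch_size**2 * out_channels) == (1, 64, 64)
```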
|
||||
@@ -420,6 +420,7 @@ else:
|
||||
"SkyReelsV2ImageToVideoPipeline",
|
||||
"SkyReelsV2Pipeline",
|
||||
]
|
||||
_import_structure["nucleusmoe_image"] = ["NucleusMoEImagePipeline"]
|
||||
_import_structure["qwenimage"] = [
|
||||
"QwenImagePipeline",
|
||||
"QwenImageImg2ImgPipeline",
|
||||
@@ -768,6 +769,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
MarigoldNormalsPipeline,
|
||||
)
|
||||
from .mochi import MochiPipeline
|
||||
from .nucleusmoe_image import NucleusMoEImagePipeline
|
||||
from .omnigen import OmniGenPipeline
|
||||
from .ovis_image import OvisImagePipeline
|
||||
from .pag import (
|
||||
|
||||
@@ -77,6 +77,7 @@ from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
|
||||
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
|
||||
from .lumina import LuminaPipeline
|
||||
from .lumina2 import Lumina2Pipeline
|
||||
from .nucleusmoe_image import NucleusMoEImagePipeline
|
||||
from .ovis_image import OvisImagePipeline
|
||||
from .pag import (
|
||||
HunyuanDiTPAGPipeline,
|
||||
@@ -179,6 +180,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("helios", HeliosPipeline),
|
||||
("helios-pyramid", HeliosPyramidPipeline),
|
||||
("cogview4-control", CogView4ControlPipeline),
|
||||
("nucleusmoe-image", NucleusMoEImagePipeline),
|
||||
("qwenimage", QwenImagePipeline),
|
||||
("qwenimage-controlnet", QwenImageControlNetPipeline),
|
||||
("z-image", ZImagePipeline),
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import UNet2DModel
|
||||
@@ -21,6 +23,9 @@ from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
@@ -129,6 +134,13 @@ class DDIMPipeline(DiffusionPipeline):
|
||||
else:
|
||||
image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
|
||||
|
||||
if not 0.0 <= eta <= 1.0:
|
||||
logger.warning(
|
||||
f"`eta` should be between 0 and 1 (inclusive), but received {eta}. "
|
||||
"A value of 0 corresponds to DDIM and 1 corresponds to DDPM. "
|
||||
"Unexpected results may occur for values outside this range."
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
|
||||
@@ -396,8 +396,9 @@ class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
|
||||
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
|
||||
def _unpack_latents_with_ids(
|
||||
x: torch.Tensor, x_ids: torch.Tensor, height: int | None = None, width: int | None = None
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
using position ids to scatter tokens into place
|
||||
"""
|
||||
@@ -407,8 +408,9 @@ class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
h_ids = pos[:, 1].to(torch.int64)
|
||||
w_ids = pos[:, 2].to(torch.int64)
|
||||
|
||||
h = torch.max(h_ids) + 1
|
||||
w = torch.max(w_ids) + 1
|
||||
# Use provided height/width to avoid DtoH sync from torch.max().item()
|
||||
h = height if height is not None else torch.max(h_ids) + 1
|
||||
w = width if width is not None else torch.max(w_ids) + 1
|
||||
|
||||
flat_ids = h_ids * w + w_ids
|
||||
|
||||
@@ -895,7 +897,10 @@ class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
latents = self._unpack_latents_with_ids(latents, latent_ids)
|
||||
# Pass pre-computed latent height/width to avoid DtoH sync from torch.max().item()
|
||||
latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
latents = self._unpack_latents_with_ids(latents, latent_ids, latent_height // 2, latent_width // 2)
|
||||
|
||||
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
|
||||
|
||||
src/diffusers/pipelines/nucleusmoe_image/__init__.py (new file)
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_additional_imports = {}
|
||||
_import_structure = {"pipeline_output": ["NucleusMoEImagePipelineOutput"]}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_nucleusmoe_image"] = ["NucleusMoEImagePipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_nucleusmoe_image import NucleusMoEImagePipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
for name, value in _additional_imports.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,644 @@
|
||||
# Copyright 2025 Nucleus-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen3VLForConditionalGeneration, Qwen3VLProcessor
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKLQwenImage, NucleusMoEImageTransformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import NucleusMoEImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = "You are an image generation assistant. Follow the user's prompt literally. Pay careful attention to spatial layout: objects described as on the left must appear on the left, on the right on the right. Match exact object counts and assign colors to the correct objects."
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import NucleusMoEImagePipeline
|
||||
|
||||
>>> pipe = NucleusMoEImagePipeline.from_pretrained("NucleusAI/NucleusMoE-Image", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> image = pipe(prompt, num_inference_steps=50).images[0]
|
||||
>>> image.save("nucleus_moe.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
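# Worked example with the default arguments (base_seq_len=256, max_seq_len=4096, base_shift=0.5,
# max_shift=1.15): image_seq_len=256 maps to mu=0.5, image_seq_len=4096 maps to mu=1.15, and an
# intermediate value such as image_seq_len=1024 maps to mu ≈ 0.63.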
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: int | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
timesteps: list[int] | None = None,
|
||||
sigmas: list[float] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`list[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`list[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
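# Usage sketch (mirrors the call made in `NucleusMoEImagePipeline.__call__` below; `pipe` is assumed to be an
# already-loaded pipeline instance):
#
#     sigmas = np.linspace(1.0, 1 / 50, 50)
#     timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, device="cuda", sigmas=sigmas, mu=0.63)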
|
||||
|
||||
|
||||
class NucleusMoEImagePipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using NucleusMoE.
|
||||
|
||||
This pipeline uses a single-stream DiT with Mixture-of-Experts feed-forward layers, cross-attention to a Qwen3-VL
|
||||
text encoder, and a flow-matching Euler discrete scheduler.
|
||||
|
||||
Args:
|
||||
transformer ([`NucleusMoEImageTransformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLQwenImage`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Qwen3VLForConditionalGeneration`]):
|
||||
Text encoder for computing prompt embeddings.
|
||||
processor ([`Qwen3VLProcessor`]):
|
||||
Processor for tokenizing text inputs.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: NucleusMoEImageTransformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLQwenImage,
|
||||
text_encoder: Qwen3VLForConditionalGeneration,
|
||||
processor: Qwen3VLProcessor,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
processor=processor,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
|
||||
self.default_sample_size = 128
|
||||
self.default_max_sequence_length = 1024
|
||||
self.default_return_index = -8
|
||||
|
||||
def _format_prompt(self, prompt: str, system_prompt: str | None = None) -> str:
|
||||
if system_prompt is None:
|
||||
system_prompt = DEFAULT_SYSTEM_PROMPT
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||
]
|
||||
return self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str | list[str] = None,
|
||||
device: torch.device | None = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
prompt_embeds_mask: torch.Tensor | None = None,
|
||||
max_sequence_length: int | None = None,
|
||||
return_index: int | None = None,
|
||||
):
|
||||
r"""
|
||||
Encode text prompt(s) into embeddings using the Qwen3-VL text encoder.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `list[str]`, *optional*):
|
||||
The prompt or prompts to encode.
|
||||
device (`torch.device`, *optional*):
|
||||
Torch device for the resulting tensors.
|
||||
num_images_per_prompt (`int`, defaults to 1):
|
||||
Number of images to generate per prompt.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Skips encoding when provided.
|
||||
prompt_embeds_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for pre-generated embeddings.
|
||||
max_sequence_length (`int`, defaults to 1024):
Maximum token length for the encoded prompt.
return_index (`int`, *optional*, defaults to -8):
Index of the text encoder hidden state returned as the prompt embedding.
"""
|
||||
device = device or self._execution_device
|
||||
return_index = return_index or self.default_return_index
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
formatted = [self._format_prompt(p) for p in prompt]
|
||||
|
||||
inputs = self.processor(
|
||||
text=formatted,
|
||||
padding="longest",
|
||||
pad_to_multiple_of=8,
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
).to(device=device)
|
||||
|
||||
prompt_embeds_mask = inputs.attention_mask
|
||||
|
||||
outputs = self.text_encoder(**inputs, use_cache=False, return_dict=True, output_hidden_states=True)
|
||||
prompt_embeds = outputs.hidden_states[return_index]
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.to(device=device)
|
||||
if prompt_embeds_mask is not None:
|
||||
prompt_embeds_mask = prompt_embeds_mask.to(device=device)
|
||||
|
||||
if num_images_per_prompt > 1:
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
if prompt_embeds_mask is not None:
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if prompt_embeds_mask is not None and prompt_embeds_mask.all():
|
||||
prompt_embeds_mask = None
|
||||
|
||||
return prompt_embeds, prompt_embeds_mask
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
prompt_embeds_mask=None,
|
||||
negative_prompt_embeds=None,
|
||||
negative_prompt_embeds_mask=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
max_sequence_length=None,
|
||||
return_index=None,
|
||||
):
|
||||
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} "
|
||||
f"but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, "
|
||||
f"but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
|
||||
"Please make sure to only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError("Provide either `prompt` or `prompt_embeds`. Cannot leave both undefined.")
|
||||
elif prompt is not None and not isinstance(prompt, (str, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and "
|
||||
f"`negative_prompt_embeds`: {negative_prompt_embeds}. "
|
||||
"Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if return_index is not None and abs(return_index) >= self.text_encoder.config.text_config.num_hidden_layers:
|
||||
raise ValueError(
|
||||
f"absolute value of `return_index` cannot be >= {self.text_encoder.config.text_config.num_hidden_layers} "
|
||||
f"but is {abs(return_index)}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents, batch_size, num_channels_latents, height, width, patch_size):
|
||||
latents = latents.view(
|
||||
batch_size, num_channels_latents, height // patch_size, patch_size, width // patch_size, patch_size
|
||||
)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
||||
latents = latents.reshape(
|
||||
batch_size, (height // patch_size) * (width // patch_size), num_channels_latents * patch_size * patch_size
|
||||
)
|
||||
return latents
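# Shape walk-through derived from the reshapes above: an input of shape
# (batch_size, num_channels_latents, height, width) with patch_size=2 becomes
# (batch_size, (height // 2) * (width // 2), num_channels_latents * 4), i.e. one token per 2x2 latent patch.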
|
||||
|
||||
@staticmethod
|
||||
def _unpack_latents(latents, height, width, patch_size, vae_scale_factor):
|
||||
batch_size, num_patches, channels = latents.shape
|
||||
height = patch_size * (int(height) // (vae_scale_factor * patch_size))
|
||||
width = patch_size * (int(width) // (vae_scale_factor * patch_size))
|
||||
latents = latents.view(
|
||||
batch_size,
|
||||
height // patch_size,
|
||||
width // patch_size,
|
||||
channels // (patch_size * patch_size),
|
||||
patch_size,
|
||||
patch_size,
|
||||
)
|
||||
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
||||
latents = latents.reshape(batch_size, channels // (patch_size * patch_size), 1, height, width)
|
||||
return latents
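# Inverse of `_pack_latents`: the (batch_size, num_patches, channels) token sequence is folded back into a
# (batch_size, channels // patch_size**2, 1, height, width) tensor, with the singleton frame dimension matching
# the 5D layout passed to `self.vae.decode` later in `__call__`.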
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
patch_size,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = patch_size * (int(height) // (self.vae_scale_factor * patch_size))
|
||||
width = patch_size * (int(width) // (self.vae_scale_factor * patch_size))
|
||||
shape = (batch_size, 1, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width, patch_size)
|
||||
return latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str | list[str] = None,
|
||||
negative_prompt: str | list[str] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: list[float] | None = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
max_sequence_length: int | None = None,
|
||||
return_index: int | None = None,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
prompt_embeds_mask: torch.Tensor | None = None,
|
||||
negative_prompt_embeds: torch.Tensor | None = None,
|
||||
negative_prompt_embeds_mask: torch.Tensor | None = None,
|
||||
output_type: str | None = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
callback_on_step_end: Callable[[int, int, dict], None] | None = None,
|
||||
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `list[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
negative_prompt (`str` or `list[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, an empty string is used when
`guidance_scale > 1`.
guidance_scale (`float`, *optional*, defaults to 4.0):
Classifier-free guidance scale. Values greater than 1 enable classifier-free guidance.
|
||||
height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`list[float]`, *optional*):
|
||||
Custom sigmas for the denoising schedule. If not defined, a linear schedule is used.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
|
||||
One or a list of torch generators to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents to be used as inputs for image generation.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings.
|
||||
prompt_embeds_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for pre-generated text embeddings.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings.
|
||||
negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
|
||||
Attention mask for pre-generated negative text embeddings.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between `"pil"`, `"np"`, or `"latent"`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`NucleusMoEImagePipelineOutput`] instead of a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Kwargs passed to the attention processor.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function called at the end of each denoising step.
|
||||
callback_on_step_end_tensor_inputs (`list`, *optional*):
|
||||
Tensor inputs for the `callback_on_step_end` function.
|
||||
max_sequence_length (`int`, *optional*, defaults to 1024):
Maximum sequence length for the text prompt.
return_index (`int`, *optional*, defaults to -8):
Index of the text encoder hidden state used as the prompt embedding.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`NucleusMoEImagePipelineOutput`] or `tuple`:
|
||||
[`NucleusMoEImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple` where the first element
|
||||
is a list with the generated images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
max_sequence_length = max_sequence_length or self.default_max_sequence_length
|
||||
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
negative_prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
max_sequence_length=max_sequence_length,
|
||||
return_index=return_index,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs or {}
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
has_neg_prompt = negative_prompt is not None or (
|
||||
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
|
||||
)
|
||||
do_cfg = guidance_scale > 1
|
||||
|
||||
if do_cfg and not has_neg_prompt:
|
||||
negative_prompt = [""] * batch_size
|
||||
|
||||
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_embeds_mask=prompt_embeds_mask,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
return_index=return_index,
|
||||
)
|
||||
if do_cfg:
|
||||
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
prompt_embeds_mask=negative_prompt_embeds_mask,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
return_index=return_index,
|
||||
)
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
patch_size = self.transformer.config.patch_size
|
||||
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
patch_size,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
img_shapes = [
|
||||
(1, height // self.vae_scale_factor // patch_size, width // self.vae_scale_factor // patch_size)
|
||||
] * (batch_size * num_images_per_prompt)
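# One (frames, height_patches, width_patches) tuple per generated sample; the transformer presumably uses these
# grid sizes to build its rotary position embeddings (the config exposes `axes_dims_rope`).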
|
||||
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
self.scheduler.set_begin_index(0)
|
||||
|
||||
if self.transformer.is_cache_enabled:
|
||||
self.transformer._reset_stateful_cache()
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / self.scheduler.config.num_train_timesteps,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
img_shapes=img_shapes,
|
||||
attention_kwargs=self._attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if do_cfg:
|
||||
neg_noise_pred = self.transformer(
|
||||
hidden_states=latents,
|
||||
timestep=timestep / self.scheduler.config.num_train_timesteps,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_hidden_states_mask=negative_prompt_embeds_mask,
|
||||
img_shapes=img_shapes,
|
||||
attention_kwargs=self._attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
comb_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)
|
||||
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
|
||||
noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
|
||||
noise_pred = comb_pred * (cond_norm / noise_norm)
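# Rescale the guided prediction to the norm of the conditional prediction so classifier-free guidance does not
# inflate the magnitude of the velocity estimate (a norm-preserving CFG variant).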
|
||||
|
||||
noise_pred = -noise_pred
|
||||
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
latents = self._unpack_latents(latents, height, width, patch_size, self.vae_scale_factor)
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents / latents_std + latents_mean
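# `latents_std` holds the reciprocal of the configured std, so this line rescales latents back to the VAE's
# latent distribution (latents * std + mean) before decoding.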
|
||||
image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return NucleusMoEImagePipelineOutput(images=image)
|
||||
src/diffusers/pipelines/nucleusmoe_image/pipeline_output.py (new file)
@@ -0,0 +1,20 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class NucleusMoEImagePipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for NucleusMoE Image pipelines.
|
||||
|
||||
Args:
|
||||
images (`list[PIL.Image.Image]` or `np.ndarray`):
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: list[PIL.Image.Image] | np.ndarray
|
||||
@@ -574,6 +574,10 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# We set the index here to remove DtoH sync, helpful especially during compilation.
|
||||
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
|
||||
self.scheduler.set_begin_index(0)
|
||||
|
||||
if self.config.boundary_ratio is not None:
|
||||
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
|
||||
else:
|
||||
|
||||
@@ -903,8 +903,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
rks.append(rk)
|
||||
D1s.append((mi - m0) / rk)
|
||||
|
||||
rks.append(1.0)
|
||||
rks = torch.tensor(rks, device=device)
|
||||
rks.append(torch.ones((), device=device))
|
||||
rks = torch.stack(rks)
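# Building `rks` from on-device scalars and `torch.stack` (instead of `torch.tensor` over a Python list of
# tensors) avoids copying each element through the host, keeping this step friendly to CUDA graphs and
# torch.compile.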
|
||||
|
||||
R = []
|
||||
b = []
|
||||
@@ -929,13 +929,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||
|
||||
R = torch.stack(R)
|
||||
b = torch.tensor(b, device=device)
|
||||
b = torch.stack(b) if len(b) > 0 else torch.tensor(b, device=device)
|
||||
|
||||
if len(D1s) > 0:
|
||||
D1s = torch.stack(D1s, dim=1) # (B, K)
|
||||
# for order 2, we use a simplified version
|
||||
if order == 2:
|
||||
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
|
||||
rhos_p = torch.ones(1, dtype=x.dtype, device=device) * 0.5
|
||||
else:
|
||||
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
|
||||
else:
|
||||
@@ -1038,8 +1038,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
rks.append(rk)
|
||||
D1s.append((mi - m0) / rk)
|
||||
|
||||
rks.append(1.0)
|
||||
rks = torch.tensor(rks, device=device)
|
||||
rks.append(torch.ones((), device=device))
|
||||
rks = torch.stack(rks)
|
||||
|
||||
R = []
|
||||
b = []
|
||||
@@ -1064,7 +1064,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
h_phi_k = h_phi_k / hh - 1 / factorial_i
|
||||
|
||||
R = torch.stack(R)
|
||||
b = torch.tensor(b, device=device)
|
||||
b = torch.stack(b) if len(b) > 0 else torch.tensor(b, device=device)
|
||||
|
||||
if len(D1s) > 0:
|
||||
D1s = torch.stack(D1s, dim=1)
|
||||
@@ -1073,7 +1073,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# for order 1, we use a simplified version
|
||||
if order == 1:
|
||||
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
|
||||
rhos_c = torch.ones(1, dtype=x.dtype, device=device) * 0.5
|
||||
else:
|
||||
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
|
||||
|
||||
|
||||
@@ -287,6 +287,21 @@ class TaylorSeerCacheConfig(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class TextKVCacheConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
def apply_faster_cache(*args, **kwargs):
|
||||
requires_backends(apply_faster_cache, ["torch"])
|
||||
|
||||
@@ -311,6 +326,10 @@ def apply_taylorseer_cache(*args, **kwargs):
|
||||
requires_backends(apply_taylorseer_cache, ["torch"])
|
||||
|
||||
|
||||
def apply_text_kv_cache(*args, **kwargs):
|
||||
requires_backends(apply_text_kv_cache, ["torch"])
|
||||
|
||||
|
||||
class InpaintProcessor(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -1511,6 +1530,21 @@ class MultiControlNetModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class NucleusMoEImageTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class OmniGenTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -2567,6 +2567,21 @@ class MusicLDMPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class NucleusMoEImagePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class OmniGenPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -347,7 +347,17 @@ def lru_cache_unless_export(maxsize=128, typed=False):
|
||||
|
||||
@functools.wraps(fn)
|
||||
def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
if torch.compiler.is_exporting():
|
||||
compiler = getattr(torch, "compiler", None)
|
||||
is_exporting = bool(compiler and hasattr(compiler, "is_exporting") and compiler.is_exporting())
|
||||
is_compiling = bool(compiler and hasattr(compiler, "is_compiling") and compiler.is_compiling())
|
||||
|
||||
# Fallback for older builds where compiler.is_compiling is unavailable.
|
||||
if not is_compiling:
|
||||
dynamo = getattr(torch, "_dynamo", None)
|
||||
if dynamo is not None and hasattr(dynamo, "is_compiling"):
|
||||
is_compiling = dynamo.is_compiling()
|
||||
|
||||
if is_exporting or is_compiling:
|
||||
return fn(*args, **kwargs)
|
||||
return cached(*args, **kwargs)
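# Minimal usage sketch (illustrative; `compute_rope_freqs` is a hypothetical helper, not part of this diff):
# the LRU cache serves eager execution, while export/compile paths call the wrapped function directly so that
# cached tensors do not leak into the traced graph.
#
# @lru_cache_unless_export(maxsize=32)
# def compute_rope_freqs(dim: int, theta: float = 10000.0) -> torch.Tensor:
#     return 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))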
|
||||
|
||||
|
||||
@@ -13,24 +13,34 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderKLWan
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
|
||||
from .testing_utils import NewAutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLWan
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
class AutoencoderKLWanTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return AutoencoderKLWan
|
||||
|
||||
def get_autoencoder_kl_wan_config(self):
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"base_dim": 3,
|
||||
"z_dim": 16,
|
||||
@@ -39,54 +49,40 @@ class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.T
|
||||
"temperal_downsample": [False, True, True],
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
image = randn_tensor(
|
||||
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
|
||||
)
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def dummy_input_tiling(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (128, 128)
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
class TestAutoencoderKLWan(AutoencoderKLWanTesterConfig, ModelTesterMixin):
|
||||
base_precision = 1e-2
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_wan_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
class TestAutoencoderKLWanTraining(AutoencoderKLWanTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for AutoencoderKLWan."""
|
||||
|
||||
def prepare_init_args_and_inputs_for_tiling(self):
|
||||
init_dict = self.get_autoencoder_kl_wan_config()
|
||||
inputs_dict = self.dummy_input_tiling
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skip("Gradient checkpointing has not been implemented yet")
|
||||
@pytest.mark.skip(reason="Gradient checkpointing has not been implemented yet")
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test not supported")
|
||||
def test_forward_with_norm_groups(self):
|
||||
|
||||
class TestAutoencoderKLWanMemory(AutoencoderKLWanTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for AutoencoderKLWan."""
|
||||
|
||||
@pytest.mark.skip(reason="RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_memory(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_inference(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
@pytest.mark.skip(reason="RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_training(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestAutoencoderKLWanSlicingTiling(AutoencoderKLWanTesterConfig, NewAutoencoderTesterMixin):
|
||||
"""Slicing and tiling tests for AutoencoderKLWan."""
|
||||
|
||||
@@ -145,3 +145,138 @@ class AutoencoderTesterMixin:
|
||||
output_without_slicing.detach().cpu().numpy().all(),
|
||||
output_without_slicing_2.detach().cpu().numpy().all(),
|
||||
), "Without slicing outputs should match with the outputs when slicing is manually disabled."
|
||||
|
||||
|
||||
class NewAutoencoderTesterMixin:
|
||||
@staticmethod
|
||||
def _accepts_generator(model):
|
||||
model_sig = inspect.signature(model.forward)
|
||||
accepts_generator = "generator" in model_sig.parameters
|
||||
return accepts_generator
|
||||
|
||||
@staticmethod
|
||||
def _accepts_norm_num_groups(model_class):
|
||||
model_sig = inspect.signature(model_class.__init__)
|
||||
accepts_norm_groups = "norm_num_groups" in model_sig.parameters
|
||||
return accepts_norm_groups
|
||||
|
||||
def test_forward_with_norm_groups(self):
|
||||
if not self._accepts_norm_num_groups(self.model_class):
|
||||
pytest.skip(f"Test not supported for {self.model_class.__name__}")
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["norm_num_groups"] = 16
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
assert output is not None
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
def test_enable_disable_tiling(self):
|
||||
if not hasattr(self.model_class, "enable_tiling"):
|
||||
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
if not hasattr(model, "use_tiling"):
|
||||
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
|
||||
|
||||
inputs_dict.update({"return_dict": False})
|
||||
_ = inputs_dict.pop("generator", None)
|
||||
accepts_generator = self._accepts_generator(model)
|
||||
|
||||
with torch.no_grad():
|
||||
torch.manual_seed(0)
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_without_tiling = model(**inputs_dict)[0]
|
||||
if isinstance(output_without_tiling, DecoderOutput):
|
||||
output_without_tiling = output_without_tiling.sample
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.enable_tiling()
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_with_tiling = model(**inputs_dict)[0]
|
||||
if isinstance(output_with_tiling, DecoderOutput):
|
||||
output_with_tiling = output_with_tiling.sample
|
||||
|
||||
assert (output_without_tiling.cpu() - output_with_tiling.cpu()).max() < 0.5, (
|
||||
"VAE tiling should not affect the inference results"
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.disable_tiling()
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_without_tiling_2 = model(**inputs_dict)[0]
|
||||
if isinstance(output_without_tiling_2, DecoderOutput):
|
||||
output_without_tiling_2 = output_without_tiling_2.sample
|
||||
|
||||
assert torch.allclose(output_without_tiling.cpu(), output_without_tiling_2.cpu()), (
|
||||
"Without tiling outputs should match with the outputs when tiling is manually disabled."
|
||||
)
|
||||
|
||||
def test_enable_disable_slicing(self):
|
||||
if not hasattr(self.model_class, "enable_slicing"):
|
||||
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
if not hasattr(model, "use_slicing"):
|
||||
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
|
||||
|
||||
inputs_dict.update({"return_dict": False})
|
||||
_ = inputs_dict.pop("generator", None)
|
||||
accepts_generator = self._accepts_generator(model)
|
||||
|
||||
with torch.no_grad():
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_without_slicing = model(**inputs_dict)[0]
|
||||
if isinstance(output_without_slicing, DecoderOutput):
|
||||
output_without_slicing = output_without_slicing.sample
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.enable_slicing()
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_with_slicing = model(**inputs_dict)[0]
|
||||
if isinstance(output_with_slicing, DecoderOutput):
|
||||
output_with_slicing = output_with_slicing.sample
|
||||
|
||||
assert (output_without_slicing.cpu() - output_with_slicing.cpu()).max() < 0.5, (
|
||||
"VAE slicing should not affect the inference results"
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.disable_slicing()
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_without_slicing_2 = model(**inputs_dict)[0]
|
||||
if isinstance(output_without_slicing_2, DecoderOutput):
|
||||
output_without_slicing_2 = output_without_slicing_2.sample
|
||||
|
||||
assert torch.allclose(output_without_slicing.cpu(), output_without_slicing_2.cpu()), (
|
||||
"Without slicing outputs should match with the outputs when slicing is manually disabled."
|
||||
)
|
||||
|
||||
@@ -0,0 +1,220 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import NucleusMoEImageTransformer2DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class NucleusMoEImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return NucleusMoEImageTransformer2DModel
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
return (16, 16)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
return (16, 16)
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 4,
|
||||
"joint_attention_dim": 16,
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
"moe_enabled": False,
|
||||
"capacity_factors": [8.0, 8.0],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict:
|
||||
batch_size = 1
|
||||
in_channels = 16
|
||||
joint_attention_dim = 16
|
||||
height = width = 4
|
||||
sequence_length = 8
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, in_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
img_shapes = [(1, height, width)] * batch_size
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformer(NucleusMoEImageTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_with_attention_mask(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Mask out some text tokens
|
||||
mask = inputs["encoder_hidden_states_mask"].clone()
|
||||
mask[:, 4:] = 0
|
||||
inputs["encoder_hidden_states_mask"] = mask
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
|
||||
def test_without_attention_mask(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
inputs["encoder_hidden_states_mask"] = None
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformerMemory(NucleusMoEImageTransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for NucleusMoE Image Transformer."""
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformerTraining(NucleusMoEImageTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for NucleusMoE Image Transformer."""
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformerAttention(NucleusMoEImageTransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for NucleusMoE Image Transformer."""
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformerLoRA(NucleusMoEImageTransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for NucleusMoE Image Transformer."""
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformerLoRAHotSwap(
|
||||
NucleusMoEImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin
|
||||
):
|
||||
"""LoRA hot-swapping tests for NucleusMoE Image Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict:
|
||||
batch_size = 1
|
||||
in_channels = 16
|
||||
joint_attention_dim = 16
|
||||
sequence_length = 8
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, in_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
img_shapes = [(1, height, width)] * batch_size
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformerCompile(NucleusMoEImageTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
"""Torch compile tests for NucleusMoE Image Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict:
|
||||
batch_size = 1
|
||||
in_channels = 16
|
||||
joint_attention_dim = 16
|
||||
sequence_length = 8
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, in_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, joint_attention_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length), dtype=torch.long).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
img_shapes = [(1, height, width)] * batch_size
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformerBitsAndBytes(NucleusMoEImageTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for NucleusMoE Image Transformer."""
|
||||
|
||||
|
||||
class TestNucleusMoEImageTransformerTorchAo(NucleusMoEImageTransformerTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for NucleusMoE Image Transformer."""
|
||||
tests/pipelines/nucleusmoe_image/__init__.py (new file)
tests/pipelines/nucleusmoe_image/test_nucleusmoe_image.py (new file)
@@ -0,0 +1,337 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration, Qwen3VLProcessor
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLQwenImage,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
NucleusMoEImagePipeline,
|
||||
NucleusMoEImageTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class NucleusMoEImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = NucleusMoEImagePipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
supports_dduf = False
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = NucleusMoEImageTransformer2DModel(
            patch_size=2,
            in_channels=16,
            out_channels=4,
            num_layers=2,
            attention_head_dim=16,
            num_attention_heads=4,
            joint_attention_dim=16,
            axes_dims_rope=(8, 4, 4),
            moe_enabled=False,
            capacity_factors=[8.0, 8.0],
        )

        torch.manual_seed(0)
        z_dim = 4
        vae = AutoencoderKLQwenImage(
            base_dim=z_dim * 6,
            z_dim=z_dim,
            dim_mult=[1, 2, 4],
            num_res_blocks=1,
            temperal_downsample=[False, True],
            # fmt: off
            latents_mean=[0.0] * z_dim,
            latents_std=[1.0] * z_dim,
            # fmt: on
        )

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler()

        torch.manual_seed(0)
        config = Qwen3VLConfig(
            text_config={
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_hidden_layers": 8,
                "num_attention_heads": 2,
                "num_key_value_heads": 2,
                "rope_scaling": {
                    "mrope_section": [1, 1, 2],
                    "rope_type": "default",
                    "type": "default",
                },
                "rope_theta": 1000000.0,
                "vocab_size": 151936,
                "head_dim": 8,
            },
            vision_config={
                "depth": 2,
                "hidden_size": 16,
                "intermediate_size": 16,
                "num_heads": 2,
                "out_channels": 16,
            },
        )
        text_encoder = Qwen3VLForConditionalGeneration(config).eval()
        processor = Qwen3VLProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "processor": processor,
        }
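        # These keys are assumed to match NucleusMoEImagePipeline's __init__ signature, since the tests below
        # build the pipeline with `self.pipeline_class(**components)`.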
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        inputs = {
            "prompt": "A cat sitting on a mat",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "return_index": -1,
            "guidance_scale": 1.0,
            "height": 32,
            "width": 32,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
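        # 32x32 outputs with only 2 inference steps keep this fast test CPU-friendly; output_type="pt" yields
        # CHW tensors, which is why a couple of PipelineTesterMixin checks are overridden further down.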
        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        generated_image = image[0]
        self.assertEqual(generated_image.shape, (3, 32, 32))

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)

    def test_true_cfg(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        inputs["guidance_scale"] = 4.0
        inputs["negative_prompt"] = "low quality"
        image = pipe(**inputs).images
        self.assertEqual(image[0].shape, (3, 32, 32))

    def test_prompt_embeds(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
            prompt=inputs["prompt"],
            device=device,
            max_sequence_length=inputs["max_sequence_length"],
        )

        inputs_with_embeds = self.get_dummy_inputs(device)
        inputs_with_embeds.pop("prompt")
        inputs_with_embeds["prompt_embeds"] = prompt_embeds
        inputs_with_embeds["prompt_embeds_mask"] = prompt_embeds_mask

        image = pipe(**inputs_with_embeds).images
        self.assertEqual(image[0].shape, (3, 32, 32))

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        # PipelineTesterMixin compares outputs with assert_mean_pixel_difference, which assumes HWC numpy/PIL layout.
        # With output_type="pt", tensors are CHW; numpy_to_pil then fails. Match QwenImage: only assert max diff.
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

    def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
        # PipelineTesterMixin only keeps components whose keys contain "text" or "tokenizer"; this pipeline also
        # needs `processor` for encode_prompt (apply_chat_template). Mirror the mixin with that key included.
        if not hasattr(self.pipeline_class, "encode_prompt"):
            return

        components = self.get_dummy_components()
        for key in components:
            if "text_encoder" in key and hasattr(components[key], "eval"):
                components[key].eval()

        def _is_text_stack_component(k):
            return "text" in k or "tokenizer" in k or k == "processor"

        components_with_text_encoders = {}
        for k in components:
            if _is_text_stack_component(k):
                components_with_text_encoders[k] = components[k]
            else:
                components_with_text_encoders[k] = None
        pipe_with_just_text_encoder = self.pipeline_class(**components_with_text_encoders)
        pipe_with_just_text_encoder = pipe_with_just_text_encoder.to(torch_device)

        inputs = self.get_dummy_inputs(torch_device)
        encode_prompt_signature = inspect.signature(pipe_with_just_text_encoder.encode_prompt)
        encode_prompt_parameters = list(encode_prompt_signature.parameters.values())

        required_params = []
        for param in encode_prompt_parameters:
            if param.name == "self" or param.name == "kwargs":
                continue
            if param.default is inspect.Parameter.empty:
                required_params.append(param.name)

        encode_prompt_param_names = [p.name for p in encode_prompt_parameters if p.name != "self"]
        input_keys = list(inputs.keys())
        encode_prompt_inputs = {k: inputs.pop(k) for k in input_keys if k in encode_prompt_param_names}

        pipe_call_signature = inspect.signature(pipe_with_just_text_encoder.__call__)
        pipe_call_parameters = pipe_call_signature.parameters

        for required_param_name in required_params:
            if required_param_name not in encode_prompt_inputs:
                pipe_call_param = pipe_call_parameters.get(required_param_name, None)
                if pipe_call_param is not None and pipe_call_param.default is not inspect.Parameter.empty:
                    encode_prompt_inputs[required_param_name] = pipe_call_param.default
                elif extra_required_param_value_dict is not None and isinstance(extra_required_param_value_dict, dict):
                    encode_prompt_inputs[required_param_name] = extra_required_param_value_dict[required_param_name]
                else:
                    raise ValueError(
                        f"Required parameter '{required_param_name}' in "
                        f"encode_prompt has no default in either encode_prompt or __call__."
                    )

        with torch.no_grad():
            encoded_prompt_outputs = pipe_with_just_text_encoder.encode_prompt(**encode_prompt_inputs)
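        # ReturnNameVisitor is used below to walk encode_prompt's source and recover the names of its return
        # values, so the outputs can be mapped onto the matching __call__ kwargs (e.g. prompt_embeds,
        # prompt_embeds_mask).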

        ast_visitor = ReturnNameVisitor()
        encode_prompt_tree = ast_visitor.get_ast_tree(cls=self.pipeline_class)
        ast_visitor.visit(encode_prompt_tree)
        prompt_embed_kwargs = ast_visitor.return_names
        prompt_embeds_kwargs = dict(zip(prompt_embed_kwargs, encoded_prompt_outputs))

        adapted_prompt_embeds_kwargs = {
            k: prompt_embeds_kwargs.pop(k) for k in list(prompt_embeds_kwargs.keys()) if k in pipe_call_parameters
        }

        components_with_text_encoders = {}
        for k in components:
            if _is_text_stack_component(k):
                components_with_text_encoders[k] = None
            else:
                components_with_text_encoders[k] = components[k]
        pipe_without_text_encoders = self.pipeline_class(**components_with_text_encoders).to(torch_device)

        pipe_without_tes_inputs = {**inputs, **adapted_prompt_embeds_kwargs}
        if (
            pipe_call_parameters.get("negative_prompt", None) is not None
            and pipe_call_parameters.get("negative_prompt").default is not None
        ):
            pipe_without_tes_inputs.update({"negative_prompt": None})

        if (
            pipe_call_parameters.get("prompt", None) is not None
            and pipe_call_parameters.get("prompt").default is inspect.Parameter.empty
            and pipe_call_parameters.get("prompt_embeds", None) is not None
            and pipe_call_parameters.get("prompt_embeds").default is None
        ):
            pipe_without_tes_inputs.update({"prompt": None})

        pipe_out = pipe_without_text_encoders(**pipe_without_tes_inputs)[0]

        full_pipe = self.pipeline_class(**components).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)
        pipe_out_2 = full_pipe(**inputs)[0]

        if isinstance(pipe_out, np.ndarray) and isinstance(pipe_out_2, np.ndarray):
            self.assertTrue(np.allclose(pipe_out, pipe_out_2, atol=atol, rtol=rtol))
        elif isinstance(pipe_out, torch.Tensor) and isinstance(pipe_out_2, torch.Tensor):
            self.assertTrue(torch.allclose(pipe_out, pipe_out_2, atol=atol, rtol=rtol))