mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-27 19:07:39 +08:00
Compare commits
11 Commits
chroma-lon
...
profiling-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a410b4958c | ||
|
|
bfbaf079cd | ||
|
|
bf5131fba9 | ||
|
|
6a23a771aa | ||
|
|
96506c85d0 | ||
|
|
179fa51342 | ||
|
|
60d4148529 | ||
|
|
b2b6330a54 | ||
|
|
e4d6293b4d | ||
|
|
eddef12a54 | ||
|
|
af96109435 |
250
examples/profiling/README.md
Normal file
250
examples/profiling/README.md
Normal file
@@ -0,0 +1,250 @@
|
||||
# Profiling Plan: Diffusers Pipeline Profiling with torch.profiler
|
||||
|
||||
Education materials to strategically profile pipelines to potentially improve their
|
||||
runtime with `torch.compile`. To set these pipelines up for success with `torch.compile`,
|
||||
we often have to get rid of DtoH syncs, CPU overheads, kernel launch delays, and
|
||||
graph breaks. In this context, profiling serves that purpose for us.
|
||||
|
||||
Thanks to Claude Code for paircoding! We acknowledge the [Claude of OSS](https://claude.com/contact-sales/claude-for-oss) support provided to us.
|
||||
|
||||
## Table of contents
|
||||
|
||||
* [Context](#context)
|
||||
* [Target pipelines](#target-pipelines)
|
||||
* [Approach taken](#approach)
|
||||
* [Verification](#verification)
|
||||
* [Interpretation](#interpreting-traces-in-perfetto-ui)
|
||||
* [Taking profiling-guided steps for improvements](#afterwards)
|
||||
|
||||
Jump to the "Verification" section to get started right away.
|
||||
|
||||
## Context
|
||||
|
||||
We want to uncover CPU overhead, CPU-GPU sync points, and other bottlenecks in popular diffusers pipelines — especially issues that become non-trivial under `torch.compile`. The approach is inspired by [flux-fast's run_benchmark.py](https://github.com/huggingface/flux-fast/blob/0a1dcc91658f0df14cd7fce862a5c8842784c6da/run_benchmark.py#L66-L85) which uses `torch.profiler` with method-level annotations, and motivated by issues like [diffusers#11696](https://github.com/huggingface/diffusers/pull/11696) (DtoH sync from scheduler `.item()` call).
|
||||
|
||||
## Target Pipelines
|
||||
|
||||
| Pipeline | Type | Checkpoint | Steps |
|
||||
|----------|------|-----------|-------|
|
||||
| `FluxPipeline` | text-to-image | `black-forest-labs/FLUX.1-dev` | 2 |
|
||||
| `Flux2KleinPipeline` | text-to-image | `black-forest-labs/FLUX.2-klein-base-9B` | 2 |
|
||||
| `WanPipeline` | text-to-video | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | 2 |
|
||||
| `LTX2Pipeline` | text-to-video | `Lightricks/LTX-2` | 2 |
|
||||
| `QwenImagePipeline` | text-to-image | `Qwen/Qwen-Image` | 2 |
|
||||
|
||||
> [!NOTE]
|
||||
> We use realistic inference call hyperparameters that mimic how these pipelines will be actually used. This
|
||||
> include using classifier-free guidance (where applicable), reasonable dimensions such 1024x1024, etc.
|
||||
> But we keep the overall running time to a bare minimum (hence 2 `num_inference_steps`).
|
||||
|
||||
## Approach
|
||||
|
||||
Follow the flux-fast pattern: **annotate key pipeline methods** with `torch.profiler.record_function` wrappers, then run the pipeline under `torch.profiler.profile` and export a Chrome trace.
|
||||
|
||||
### New Files
|
||||
|
||||
```
|
||||
profiling/
|
||||
profiling_utils.py # Annotation helper + profiler setup
|
||||
profiling_pipelines.py # CLI entry point with pipeline configs
|
||||
```
|
||||
|
||||
### Step 1: `profiling_utils.py` — Annotation and Profiler Infrastructure
|
||||
|
||||
**A) `annotate(func, name)` helper** (same pattern as flux-fast):
|
||||
|
||||
```python
|
||||
def annotate(func, name):
|
||||
"""Wrap a function with torch.profiler.record_function for trace annotation."""
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
with torch.profiler.record_function(name):
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
```
|
||||
|
||||
**B) `annotate_pipeline(pipe)` function** — applies annotations to key methods on any pipeline:
|
||||
|
||||
- `pipe.transformer.forward` → `"transformer_forward"`
|
||||
- `pipe.vae.decode` → `"vae_decode"` (if present)
|
||||
- `pipe.vae.encode` → `"vae_encode"` (if present)
|
||||
- `pipe.scheduler.step` → `"scheduler_step"`
|
||||
- `pipe.encode_prompt` → `"encode_prompt"` (if present, for full-pipeline profiling)
|
||||
|
||||
This is non-invasive — it monkey-patches bound methods without modifying source.
|
||||
|
||||
**C) `PipelineProfiler` class:**
|
||||
|
||||
- `__init__(pipeline_config, output_dir, mode="eager"|"compile")`
|
||||
- `setup_pipeline()` → loads from pretrained, optionally compiles transformer, calls `annotate_pipeline()`
|
||||
- `run()`:
|
||||
1. Warm up with 1 unannotated run
|
||||
2. Profile 1 run with `torch.profiler.profile`:
|
||||
- `activities=[CPU, CUDA]`
|
||||
- `record_shapes=True`
|
||||
- `profile_memory=True`
|
||||
- `with_stack=True`
|
||||
3. Export Chrome trace JSON
|
||||
4. Print `key_averages()` summary table (sorted by CUDA time) to stdout
|
||||
|
||||
### Step 2: `profiling_pipelines.py` — CLI with Pipeline Configs
|
||||
|
||||
**Pipeline config registry** — each entry specifies:
|
||||
|
||||
- `pipeline_cls`, `pretrained_model_name_or_path`, `torch_dtype`
|
||||
- `call_kwargs` with pipeline-specific defaults:
|
||||
|
||||
| Pipeline | Resolution | Frames | Steps | Extra |
|
||||
|----------|-----------|--------|-------|-------|
|
||||
| Flux | 1024x1024 | — | 2 | `guidance_scale=3.5` |
|
||||
| Flux2Klein | 1024x1024 | — | 2 | `guidance_scale=3.5` |
|
||||
| Wan | 480x832 | 81 | 2 | — |
|
||||
| LTX2 | 768x512 | 121 | 2 | `guidance_scale=4.0` |
|
||||
| QwenImage | 1024x1024 | — | 2 | `true_cfg_scale=4.0` |
|
||||
|
||||
All configs use `output_type="latent"` by default (skip VAE decode for cleaner denoising-loop traces).
|
||||
|
||||
**CLI flags:**
|
||||
|
||||
- `--pipeline flux|flux2|wan|ltx2|qwenimage|all`
|
||||
- `--mode eager|compile|both`
|
||||
- `--output_dir profiling_results/`
|
||||
- `--num_steps N` (override, default 4)
|
||||
- `--full_decode` (switch output_type from `"latent"` to `"pil"` to include VAE)
|
||||
- `--compile_mode default|reduce-overhead|max-autotune`
|
||||
- `--compile_fullgraph` flag
|
||||
|
||||
**Output:** `{output_dir}/{pipeline}_{mode}.json` Chrome trace + stdout summary.
|
||||
|
||||
### Step 3: Known Sync Issues to Validate
|
||||
|
||||
The profiling should surface these known/suspected issues:
|
||||
|
||||
1. **Scheduler DtoH sync via `nonzero().item()`** — For Flux, this was fixed by adding `scheduler.set_begin_index(0)` before the denoising loop ([diffusers#11696](https://github.com/huggingface/diffusers/pull/11696)). Profiling should reveal whether similar sync points exist in other pipelines.
|
||||
|
||||
2. **`modulate_index` tensor rebuilt every forward in `transformer_qwenimage.py`** (line 901-905) — Python list comprehension + `torch.tensor()` each step. Minor but visible in trace.
|
||||
|
||||
3. **Any other `.item()`, `.cpu()`, `.numpy()` calls** in the denoising loop hot path — the profiler's `with_stack=True` will surface these as CPU stalls with Python stack traces.
|
||||
|
||||
## Verification
|
||||
|
||||
1. Run: `python profiling/profiling_pipelines.py --pipeline flux --mode eager --num_steps 2`
|
||||
2. Verify `profiling_results/flux_eager.json` is produced
|
||||
3. Open trace in [Perfetto UI](https://ui.perfetto.dev/) — confirm:
|
||||
- `transformer_forward` and `scheduler_step` annotations visible
|
||||
- CPU and CUDA timelines present
|
||||
- Stack traces visible on CPU events
|
||||
4. Run with `--mode compile` and compare trace for fewer/fused CUDA kernels
|
||||
|
||||
You can also use the `run_profiling.sh` script to bulk launch runs for different pipelines.
|
||||
|
||||
## Interpreting Traces in Perfetto UI
|
||||
|
||||
Open the exported `.json` trace at [ui.perfetto.dev](https://ui.perfetto.dev/). The trace has two main rows: **CPU** (top) and **CUDA** (bottom). In Perfetto, the CPU row is typically labeled with the process/thread name (e.g., `python (PID)` or `MainThread`) and appears at the top. The CUDA row is labeled `GPU 0` (or similar) and appears below the CPU rows.
|
||||
|
||||
**Navigation:** Use `W` to zoom in, `S` to zoom out, and `A`/`D` to pan left/right. You can also scroll to zoom and click-drag to pan. Use `Shift+scroll` to scroll vertically through rows.
|
||||
|
||||
### What to look for
|
||||
|
||||
**1. Gaps between CUDA kernels**
|
||||
|
||||
Zoom into the CUDA row during the denoising loop. Ideally, GPU kernels should be back-to-back with no gaps. Gaps mean the GPU is idle waiting for the CPU to launch the next kernel. Common causes:
|
||||
- Python overhead between ops (visible as CPU slices in the CPU row during the gap)
|
||||
- DtoH sync (`.item()`, `.cpu()`) forcing the GPU to drain before the CPU can proceed
|
||||
|
||||
**2. CPU stalls (DtoH syncs)**
|
||||
|
||||
These appear on the **CPU row** (not the CUDA row) — they are CPU-side blocking calls that wait for the GPU to finish. Look for long slices labeled `cudaStreamSynchronize` or `cudaDeviceSynchronize`. To find them: zoom into the CPU row during a denoising step and look for unusually wide slices, or use Perfetto's search bar (press `/`) and type `cudaStreamSynchronize` to jump directly to matching events. Click on a slice — if `with_stack=True` was enabled, the bottom panel ("Current Selection") shows the Python stack trace pointing to the exact line causing the sync (e.g., a `.item()` call in the scheduler).
|
||||
|
||||
**3. Annotated regions**
|
||||
|
||||
Our `record_function` annotations (`transformer_forward`, `scheduler_step`, etc.) appear as labeled spans on the CPU row. This lets you quickly:
|
||||
- Measure how long each phase takes (click a span to see duration)
|
||||
- See if `scheduler_step` is disproportionately expensive relative to `transformer_forward` (it should be negligible)
|
||||
- Spot unexpected CPU work between annotated regions
|
||||
|
||||
**4. Eager vs compile comparison**
|
||||
|
||||
Open both traces side by side (two Perfetto tabs). Key differences to look for:
|
||||
- **Fewer, wider CUDA kernels** in compile mode (fused ops) vs many small kernels in eager
|
||||
- **Smaller CPU gaps** between kernels in compile mode (less Python dispatch overhead)
|
||||
- **CUDA kernel count per step**: to compare, zoom into a single `transformer_forward` span on the CUDA row and count the distinct kernel slices within it. In eager mode you'll typically see many narrow slices (one per op); in compile mode these fuse into fewer, wider slices. A quick way to estimate: select a time range covering one denoising step on the CUDA row — Perfetto shows the number of slices in the selection summary at the bottom. If compile mode shows a similar kernel count to eager, fusion isn't happening effectively (likely due to graph breaks).
|
||||
- **Graph breaks**: if compile mode still shows many small kernels in a section, that section likely has a graph break — check `TORCH_LOGS="+dynamo"` output for details
|
||||
|
||||
**5. Memory timeline**
|
||||
|
||||
In Perfetto, look for the memory counter track (if `profile_memory=True`). Spikes during the denoising loop suggest unexpected allocations per step. Steady-state memory during denoising is expected — growing memory is not.
|
||||
|
||||
**6. Kernel launch latency**
|
||||
|
||||
Each CUDA kernel is launched from the CPU. The CPU-side launch calls (`cudaLaunchKernel`) appear as small slices on the **CPU row** — zoom in closely to a denoising step to see them. The corresponding GPU-side kernel executions appear on the **CUDA row** directly below. You can also use Perfetto's search bar (`/`) and type `cudaLaunchKernel` to find them. The time between the CPU dispatch and the GPU kernel starting should be minimal (single-digit microseconds). If you see consistent delays > 10-20us between launch and execution:
|
||||
- The launch queue may be starved because of excessive Python work between ops
|
||||
- There may be implicit syncs forcing serialization
|
||||
- `torch.compile` should help here by batching launches — compare eager vs compile to confirm
|
||||
|
||||
To inspect this: zoom into a single denoising step, select a CUDA kernel on the GPU row, and look at the corresponding CPU-side launch slice directly above it. The horizontal offset between them is the launch latency. In a healthy trace, CPU launch slices should be well ahead of GPU execution (the CPU is "feeding" the GPU faster than it can consume).
|
||||
|
||||
### Quick checklist per pipeline
|
||||
|
||||
| Question | Where to look | Healthy | Unhealthy |
|
||||
|----------|--------------|---------|-----------|
|
||||
| GPU staying busy? | CUDA row gaps | Back-to-back kernels | Frequent gaps > 100us |
|
||||
| CPU blocking on GPU? | `cudaStreamSynchronize` slices | Rare/absent during denoise | Present every step |
|
||||
| Scheduler overhead? | `scheduler_step` span duration | < 1% of step time | > 5% of step time |
|
||||
| Compile effective? | CUDA kernel count per step | Fewer large kernels | Same as eager |
|
||||
| Kernel launch latency? | CPU launch → GPU kernel offset | < 10us, CPU ahead of GPU | > 20us or CPU trailing GPU |
|
||||
| Memory stable? | Memory counter track | Flat during denoise loop | Growing per step |
|
||||
|
||||
## Afterwards
|
||||
|
||||
To keep the profiling iterations fast, we always used [regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html). As one would expect the trace with compilation should show
|
||||
fewer kernel launches than its eager counterpart:
|
||||
|
||||
TODO: show traces
|
||||
|
||||
_(The traces above were obtained with Flux2.)_
|
||||
|
||||
### Spotting gaps between launches
|
||||
|
||||
Then a reasonable next step is to spot frequent gaps between kernel executions. In the compiled
|
||||
case, we don't spot any on the surface. But if we zone in, some become apparent.
|
||||
|
||||
TODO: show gaps in a compile trace
|
||||
|
||||
So, we provided the profile trace (with compilation) to Claude, asked it to find the instances of
|
||||
"cudaStreamSynchronize" and "cudaDeviceSynchronize", and to come up with some potential fixes.
|
||||
Claude came back pretty strong:
|
||||
|
||||
```
|
||||
Issue 1 — Gap between transformer forwards:
|
||||
- Root cause: tqdm progress bar update() calls between steps add CPU overhead (I/O, time calculations)
|
||||
- Fix: profiling/profiling_utils.py — added pipe.set_progress_bar_config(disable=True) during profiling setup.
|
||||
This eliminates the tqdm overhead from the trace. (The remaining gap from scheduler step + Python dispatch is
|
||||
inherent to eager-mode execution and should shrink significantly under torch.compile.)
|
||||
|
||||
Issue 2 — cudaStreamSynchronize during last transformer forward:
|
||||
- Root cause: _unpack_latents_with_ids() (called right after the denoising loop) computes h = torch.max(h_ids) +
|
||||
1 and w = torch.max(w_ids) + 1 on GPU tensors, then uses them as shape args for torch.zeros((h * w, ch), ...).
|
||||
This triggers an implicit .item() DtoH sync, blocking the CPU while the GPU is still finishing the last
|
||||
transformer forward's kernels.
|
||||
- Fix: Added height/width parameters to _unpack_latents_with_ids(), pre-computed from the known pixel dimensions
|
||||
at the call site.
|
||||
```
|
||||
|
||||
It still didn't eliminate the gaps as expected so, we fed that back to Claude and it spotted
|
||||
something more crucial. TODO: caching context fix.
|
||||
|
||||
With the fix applied, the improvements were visible:
|
||||
|
||||
TODO: show before and after trace
|
||||
|
||||
Before:
|
||||
|
||||
- `_set_context` total: 21.6ms (8 calls)
|
||||
- cache_context total: 21.7ms
|
||||
- CPU gaps: 5,523us / 8,007us / 5,508us
|
||||
|
||||
After:
|
||||
- `_set_context` total: 0.0ms (8 calls)
|
||||
- cache_context total: 0.1ms
|
||||
- CPU gaps: 158us / 2,777us / 136us
|
||||
182
examples/profiling/profiling_pipelines.py
Normal file
182
examples/profiling/profiling_pipelines.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
Profile diffusers pipelines with torch.profiler.
|
||||
|
||||
Usage:
|
||||
python profiling/profiling_pipelines.py --pipeline flux --mode eager
|
||||
python profiling/profiling_pipelines.py --pipeline flux --mode compile
|
||||
python profiling/profiling_pipelines.py --pipeline flux --mode both
|
||||
python profiling/profiling_pipelines.py --pipeline all --mode eager
|
||||
python profiling/profiling_pipelines.py --pipeline wan --mode eager --full_decode
|
||||
python profiling/profiling_pipelines.py --pipeline flux --mode compile --num_steps 4
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
from profiling_utils import PipelineProfiler, PipelineProfilingConfig
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROMPT = "A cat holding a sign that says hello world"
|
||||
|
||||
|
||||
def build_registry():
|
||||
"""Build the pipeline config registry. Imports are deferred to avoid loading all pipelines upfront."""
|
||||
from diffusers import FluxPipeline, Flux2KleinPipeline, WanPipeline, LTX2Pipeline, QwenImagePipeline
|
||||
|
||||
return {
|
||||
"flux": PipelineProfilingConfig(
|
||||
name="flux",
|
||||
pipeline_cls=FluxPipeline,
|
||||
pipeline_init_kwargs={
|
||||
"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev",
|
||||
"torch_dtype": torch.bfloat16,
|
||||
},
|
||||
pipeline_call_kwargs={
|
||||
"prompt": PROMPT,
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
"num_inference_steps": 4,
|
||||
"guidance_scale": 3.5,
|
||||
"output_type": "latent",
|
||||
},
|
||||
),
|
||||
"flux2": PipelineProfilingConfig(
|
||||
name="flux2",
|
||||
pipeline_cls=Flux2KleinPipeline,
|
||||
pipeline_init_kwargs={
|
||||
"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-klein-base-9B",
|
||||
"torch_dtype": torch.bfloat16,
|
||||
},
|
||||
pipeline_call_kwargs={
|
||||
"prompt": PROMPT,
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
"num_inference_steps": 4,
|
||||
"guidance_scale": 3.5,
|
||||
"output_type": "latent",
|
||||
},
|
||||
),
|
||||
"wan": PipelineProfilingConfig(
|
||||
name="wan",
|
||||
pipeline_cls=WanPipeline,
|
||||
pipeline_init_kwargs={
|
||||
"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
|
||||
"torch_dtype": torch.bfloat16,
|
||||
},
|
||||
pipeline_call_kwargs={
|
||||
"prompt": PROMPT,
|
||||
"negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
|
||||
"height": 480,
|
||||
"width": 832,
|
||||
"num_frames": 81,
|
||||
"num_inference_steps": 4,
|
||||
"output_type": "latent",
|
||||
},
|
||||
),
|
||||
"ltx2": PipelineProfilingConfig(
|
||||
name="ltx2",
|
||||
pipeline_cls=LTX2Pipeline,
|
||||
pipeline_init_kwargs={
|
||||
"pretrained_model_name_or_path": "Lightricks/LTX-2",
|
||||
"torch_dtype": torch.bfloat16,
|
||||
},
|
||||
pipeline_call_kwargs={
|
||||
"prompt": PROMPT,
|
||||
"negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
|
||||
"height": 512,
|
||||
"width": 768,
|
||||
"num_frames": 121,
|
||||
"num_inference_steps": 4,
|
||||
"guidance_scale": 4.0,
|
||||
"output_type": "latent",
|
||||
},
|
||||
),
|
||||
"qwenimage": PipelineProfilingConfig(
|
||||
name="qwenimage",
|
||||
pipeline_cls=QwenImagePipeline,
|
||||
pipeline_init_kwargs={
|
||||
"pretrained_model_name_or_path": "Qwen/Qwen-Image",
|
||||
"torch_dtype": torch.bfloat16,
|
||||
},
|
||||
pipeline_call_kwargs={
|
||||
"prompt": PROMPT,
|
||||
"negative_prompt": " ",
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
"num_inference_steps": 4,
|
||||
"true_cfg_scale": 4.0,
|
||||
"output_type": "latent",
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Profile diffusers pipelines with torch.profiler")
|
||||
parser.add_argument(
|
||||
"--pipeline",
|
||||
choices=["flux", "flux2", "wan", "ltx2", "qwenimage", "all"],
|
||||
required=True,
|
||||
help="Which pipeline to profile",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["eager", "compile", "both"],
|
||||
default="eager",
|
||||
help="Run in eager mode, compile mode, or both",
|
||||
)
|
||||
parser.add_argument("--output_dir", default="profiling_results", help="Directory for trace output")
|
||||
parser.add_argument("--num_steps", type=int, default=None, help="Override num_inference_steps")
|
||||
parser.add_argument("--full_decode", action="store_true", help="Profile including VAE decode (output_type='pil')")
|
||||
parser.add_argument(
|
||||
"--compile_mode",
|
||||
default="default",
|
||||
choices=["default", "reduce-overhead", "max-autotune"],
|
||||
help="torch.compile mode",
|
||||
)
|
||||
parser.add_argument("--compile_fullgraph", action="store_true", help="Use fullgraph=True for torch.compile")
|
||||
parser.add_argument(
|
||||
"--compile_regional",
|
||||
action="store_true",
|
||||
help="Use compile_repeated_blocks() instead of full model compile",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
registry = build_registry()
|
||||
|
||||
pipeline_names = list(registry.keys()) if args.pipeline == "all" else [args.pipeline]
|
||||
modes = ["eager", "compile"] if args.mode == "both" else [args.mode]
|
||||
|
||||
for pipeline_name in pipeline_names:
|
||||
for mode in modes:
|
||||
config = copy.deepcopy(registry[pipeline_name])
|
||||
|
||||
# Apply overrides
|
||||
if args.num_steps is not None:
|
||||
config.pipeline_call_kwargs["num_inference_steps"] = args.num_steps
|
||||
if args.full_decode:
|
||||
config.pipeline_call_kwargs["output_type"] = "pil"
|
||||
if mode == "compile":
|
||||
config.compile_kwargs = {
|
||||
"fullgraph": args.compile_fullgraph,
|
||||
"mode": args.compile_mode,
|
||||
}
|
||||
config.compile_regional = args.compile_regional
|
||||
|
||||
logger.info(f"Profiling {pipeline_name} in {mode} mode...")
|
||||
profiler = PipelineProfiler(config, args.output_dir)
|
||||
try:
|
||||
trace_file = profiler.run()
|
||||
logger.info(f"Done: {trace_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to profile {pipeline_name} ({mode}): {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
146
examples/profiling/profiling_utils.py
Normal file
146
examples/profiling/profiling_utils.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import functools
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.profiler
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def annotate(func, name):
|
||||
"""Wrap a function with torch.profiler.record_function for trace annotation."""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
with torch.profiler.record_function(name):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def annotate_pipeline(pipe):
|
||||
"""Apply profiler annotations to key pipeline methods.
|
||||
|
||||
Monkey-patches bound methods so they appear as named spans in the trace.
|
||||
Non-invasive — no source modifications required.
|
||||
"""
|
||||
annotations = [
|
||||
("transformer", "forward", "transformer_forward"),
|
||||
("vae", "decode", "vae_decode"),
|
||||
("vae", "encode", "vae_encode"),
|
||||
("scheduler", "step", "scheduler_step"),
|
||||
]
|
||||
|
||||
# Annotate sub-component methods
|
||||
for component_name, method_name, label in annotations:
|
||||
component = getattr(pipe, component_name, None)
|
||||
if component is None:
|
||||
continue
|
||||
method = getattr(component, method_name, None)
|
||||
if method is None:
|
||||
continue
|
||||
setattr(component, method_name, annotate(method, label))
|
||||
|
||||
# Annotate pipeline-level methods
|
||||
if hasattr(pipe, "encode_prompt"):
|
||||
pipe.encode_prompt = annotate(pipe.encode_prompt, "encode_prompt")
|
||||
|
||||
|
||||
def flush():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineProfilingConfig:
|
||||
name: str
|
||||
pipeline_cls: Any
|
||||
pipeline_init_kwargs: dict[str, Any]
|
||||
pipeline_call_kwargs: dict[str, Any]
|
||||
compile_kwargs: dict[str, Any] | None = field(default=None)
|
||||
compile_regional: bool = False
|
||||
|
||||
|
||||
class PipelineProfiler:
|
||||
def __init__(self, config: PipelineProfilingConfig, output_dir: str = "profiling_results"):
|
||||
self.config = config
|
||||
self.output_dir = output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
def setup_pipeline(self):
|
||||
"""Load the pipeline from pretrained, optionally compile, and annotate."""
|
||||
logger.info(f"Loading pipeline: {self.config.name}")
|
||||
pipe = self.config.pipeline_cls.from_pretrained(**self.config.pipeline_init_kwargs)
|
||||
pipe.to("cuda")
|
||||
|
||||
if self.config.compile_kwargs:
|
||||
if self.config.compile_regional:
|
||||
logger.info(f"Regional compilation (compile_repeated_blocks) with kwargs: {self.config.compile_kwargs}")
|
||||
pipe.transformer.compile_repeated_blocks(**self.config.compile_kwargs)
|
||||
else:
|
||||
logger.info(f"Full compilation with kwargs: {self.config.compile_kwargs}")
|
||||
pipe.transformer.compile(**self.config.compile_kwargs)
|
||||
|
||||
# Disable tqdm progress bar to avoid CPU overhead / IO between steps
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
|
||||
annotate_pipeline(pipe)
|
||||
return pipe
|
||||
|
||||
def run(self):
|
||||
"""Execute the profiling run: warmup, then profile one pipeline call."""
|
||||
pipe = self.setup_pipeline()
|
||||
flush()
|
||||
|
||||
mode = "compile" if self.config.compile_kwargs else "eager"
|
||||
trace_file = os.path.join(self.output_dir, f"{self.config.name}_{mode}.json")
|
||||
|
||||
# Warmup (pipeline __call__ is already decorated with @torch.no_grad())
|
||||
logger.info("Running warmup...")
|
||||
pipe(**self.config.pipeline_call_kwargs)
|
||||
flush()
|
||||
|
||||
# Profile
|
||||
logger.info("Running profiled iteration...")
|
||||
activities = [
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
]
|
||||
with torch.profiler.profile(
|
||||
activities=activities,
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=True,
|
||||
) as prof:
|
||||
with torch.profiler.record_function("pipeline_call"):
|
||||
pipe(**self.config.pipeline_call_kwargs)
|
||||
|
||||
# Export trace
|
||||
prof.export_chrome_trace(trace_file)
|
||||
logger.info(f"Chrome trace saved to: {trace_file}")
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 80)
|
||||
print(f"Profile summary: {self.config.name} ({mode})")
|
||||
print("=" * 80)
|
||||
print(
|
||||
prof.key_averages().table(
|
||||
sort_by="cuda_time_total",
|
||||
row_limit=20,
|
||||
)
|
||||
)
|
||||
|
||||
# Cleanup
|
||||
pipe.to("cpu")
|
||||
del pipe
|
||||
flush()
|
||||
|
||||
return trace_file
|
||||
46
examples/profiling/run_profiling.sh
Executable file
46
examples/profiling/run_profiling.sh
Executable file
@@ -0,0 +1,46 @@
|
||||
#!/bin/bash
|
||||
# Run profiling across all pipelines in eager and compile (regional) modes.
|
||||
#
|
||||
# Usage:
|
||||
# bash profiling/run_profiling.sh
|
||||
# bash profiling/run_profiling.sh --output_dir my_results
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
OUTPUT_DIR="profiling_results"
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--output_dir) OUTPUT_DIR="$2"; shift 2 ;;
|
||||
*) echo "Unknown arg: $1"; exit 1 ;;
|
||||
esac
|
||||
done
|
||||
NUM_STEPS=2
|
||||
# PIPELINES=("flux" "flux2" "wan" "ltx2" "qwenimage")
|
||||
PIPELINES=("flux2")
|
||||
MODES=("eager" "compile")
|
||||
|
||||
for pipeline in "${PIPELINES[@]}"; do
|
||||
for mode in "${MODES[@]}"; do
|
||||
echo "============================================================"
|
||||
echo "Profiling: ${pipeline} | mode: ${mode}"
|
||||
echo "============================================================"
|
||||
|
||||
COMPILE_ARGS=""
|
||||
if [ "$mode" = "compile" ]; then
|
||||
COMPILE_ARGS="--compile_regional --compile_fullgraph --compile_mode default"
|
||||
fi
|
||||
|
||||
python profiling/profiling_pipelines.py \
|
||||
--pipeline "$pipeline" \
|
||||
--mode "$mode" \
|
||||
--output_dir "$OUTPUT_DIR" \
|
||||
--num_steps "$NUM_STEPS" \
|
||||
$COMPILE_ARGS
|
||||
|
||||
echo ""
|
||||
done
|
||||
done
|
||||
|
||||
echo "============================================================"
|
||||
echo "All traces saved to: ${OUTPUT_DIR}/"
|
||||
echo "============================================================"
|
||||
@@ -271,12 +271,31 @@ class HookRegistry:
|
||||
if hook._is_stateful:
|
||||
hook._set_context(self._module_ref, name)
|
||||
|
||||
for registry in self._get_child_registries():
|
||||
registry._set_context(name)
|
||||
|
||||
def _get_child_registries(self) -> list["HookRegistry"]:
|
||||
"""Return registries of child modules, using a cached list when available.
|
||||
|
||||
The cache is built on first call and reused for subsequent calls. This avoids the cost of walking the full
|
||||
module tree via named_modules() on every _set_context call, which is significant for large models (e.g. ~2.7ms
|
||||
per call on Flux2).
|
||||
"""
|
||||
if not hasattr(self, "_child_registries_cache"):
|
||||
self._child_registries_cache = None
|
||||
|
||||
if self._child_registries_cache is not None:
|
||||
return self._child_registries_cache
|
||||
|
||||
registries = []
|
||||
for module_name, module in unwrap_module(self._module_ref).named_modules():
|
||||
if module_name == "":
|
||||
continue
|
||||
module = unwrap_module(module)
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook._set_context(name)
|
||||
registries.append(module._diffusers_hook)
|
||||
self._child_registries_cache = registries
|
||||
return registries
|
||||
|
||||
def __repr__(self) -> str:
|
||||
registry_repr = ""
|
||||
|
||||
@@ -397,7 +397,9 @@ class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
|
||||
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
|
||||
def _unpack_latents_with_ids(
|
||||
x: torch.Tensor, x_ids: torch.Tensor, height: int | None = None, width: int | None = None
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
using position ids to scatter tokens into place
|
||||
"""
|
||||
@@ -407,8 +409,9 @@ class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
h_ids = pos[:, 1].to(torch.int64)
|
||||
w_ids = pos[:, 2].to(torch.int64)
|
||||
|
||||
h = torch.max(h_ids) + 1
|
||||
w = torch.max(w_ids) + 1
|
||||
# Use provided height/width to avoid DtoH sync from torch.max().item()
|
||||
h = height if height is not None else torch.max(h_ids) + 1
|
||||
w = width if width is not None else torch.max(w_ids) + 1
|
||||
|
||||
flat_ids = h_ids * w + w_ids
|
||||
|
||||
@@ -895,7 +898,10 @@ class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
latents = self._unpack_latents_with_ids(latents, latent_ids)
|
||||
# Pass pre-computed latent height/width to avoid DtoH sync from torch.max().item()
|
||||
latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
latents = self._unpack_latents_with_ids(latents, latent_ids, latent_height // 2, latent_width // 2)
|
||||
|
||||
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
|
||||
|
||||
@@ -13,31 +13,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import ChromaTransformer2DModel
|
||||
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
|
||||
from diffusers.models.embeddings import ImageProjection
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
IPAdapterTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
def create_chroma_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
def create_chroma_ip_adapter_state_dict(model):
|
||||
# "ip_adapter" (cross-attention weights)
|
||||
ip_cross_attn_state_dict = {}
|
||||
key_id = 0
|
||||
|
||||
@@ -58,8 +50,11 @@ def create_chroma_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
key_id += 1
|
||||
|
||||
# "image_proj" (ImageProjection layer weights)
|
||||
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=model.config["joint_attention_dim"],
|
||||
image_embed_dim=model.config["pooled_projection_dim"],
|
||||
@@ -78,36 +73,53 @@ def create_chroma_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
)
|
||||
|
||||
del sd
|
||||
return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}
|
||||
ip_state_dict = {}
|
||||
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
class ChromaTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return ChromaTransformer2DModel
|
||||
class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = ChromaTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.8, 0.7, 0.7]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
height = width = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.8, 0.7, 0.7]
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
|
||||
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (16, 4)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
@@ -121,35 +133,11 @@ class ChromaTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"approximator_layers": 1,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
height = width = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"img_ids": randn_tensor(
|
||||
(height * width, num_image_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"txt_ids": randn_tensor(
|
||||
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
|
||||
|
||||
class TestChromaTransformer(ChromaTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_deprecated_inputs_img_txt_ids_3d(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -157,11 +145,12 @@ class TestChromaTransformer(ChromaTransformerTesterConfig, ModelTesterMixin):
|
||||
with torch.no_grad():
|
||||
output_1 = model(**inputs_dict).to_tuple()[0]
|
||||
|
||||
# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
|
||||
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
|
||||
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
|
||||
|
||||
assert text_ids_3d.ndim == 3
|
||||
assert image_ids_3d.ndim == 3
|
||||
assert text_ids_3d.ndim == 3, "text_ids_3d should be a 3d tensor"
|
||||
assert image_ids_3d.ndim == 3, "img_ids_3d should be a 3d tensor"
|
||||
|
||||
inputs_dict["txt_ids"] = text_ids_3d
|
||||
inputs_dict["img_ids"] = image_ids_3d
|
||||
@@ -169,59 +158,26 @@ class TestChromaTransformer(ChromaTransformerTesterConfig, ModelTesterMixin):
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict).to_tuple()[0]
|
||||
|
||||
assert output_1.shape == output_2.shape
|
||||
assert torch.allclose(output_1, output_2, atol=1e-5), (
|
||||
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
|
||||
"are not equal as them as 2d inputs"
|
||||
self.assertEqual(output_1.shape, output_2.shape)
|
||||
self.assertTrue(
|
||||
torch.allclose(output_1, output_2, atol=1e-5),
|
||||
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
|
||||
)
|
||||
|
||||
|
||||
class TestChromaTransformerTraining(ChromaTransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"ChromaTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestChromaTransformerCompile(ChromaTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
class ChromaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = ChromaTransformer2DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return ChromaTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
class TestChromaTransformerIPAdapter(ChromaTransformerTesterConfig, IPAdapterTesterMixin):
|
||||
@property
|
||||
def ip_adapter_processor_cls(self):
|
||||
return FluxIPAdapterJointAttnProcessor2_0
|
||||
class ChromaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
||||
model_class = ChromaTransformer2DModel
|
||||
|
||||
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
|
||||
torch.manual_seed(0)
|
||||
cross_attention_dim = getattr(model.config, "joint_attention_dim", 32)
|
||||
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
|
||||
inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}})
|
||||
return inputs_dict
|
||||
|
||||
def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
|
||||
return create_chroma_ip_adapter_state_dict(model)
|
||||
|
||||
|
||||
class TestChromaTransformerLoRA(ChromaTransformerTesterConfig, LoraTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestChromaTransformerLoRAHotSwap(ChromaTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 24
|
||||
embedding_dim = 32
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
|
||||
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
|
||||
"img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
|
||||
"txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return ChromaTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -13,50 +13,61 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import HiDreamImageTransformer2DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class HiDreamTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return HiDreamImageTransformer2DModel
|
||||
class HiDreamTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = HiDreamImageTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
model_split_percents = [0.8, 0.8, 0.9]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_channels = 4
|
||||
height = width = 32
|
||||
embedding_dim_t5, embedding_dim_llama, embedding_dim_pooled = 8, 4, 8
|
||||
sequence_length = 8
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.8, 0.8, 0.9]
|
||||
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
|
||||
encoder_hidden_states_t5 = torch.randn((batch_size, sequence_length, embedding_dim_t5)).to(torch_device)
|
||||
encoder_hidden_states_llama3 = torch.randn((batch_size, batch_size, sequence_length, embedding_dim_llama)).to(
|
||||
torch_device
|
||||
)
|
||||
pooled_embeds = torch.randn((batch_size, embedding_dim_pooled)).to(torch_device)
|
||||
timesteps = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states_t5": encoder_hidden_states_t5,
|
||||
"encoder_hidden_states_llama3": encoder_hidden_states_llama3,
|
||||
"pooled_embeds": pooled_embeds,
|
||||
"timesteps": timesteps,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 2,
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
@@ -71,43 +82,15 @@ class HiDreamTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"axes_dims_rope": (4, 2, 2),
|
||||
"max_resolution": (32, 32),
|
||||
"llama_layers": (0, 1),
|
||||
"force_inference_output": True,
|
||||
"force_inference_output": True, # TODO: as we don't implement MoE loss in training tests.
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
|
||||
num_channels = 4
|
||||
height = width = 32
|
||||
embedding_dim_t5, embedding_dim_llama, embedding_dim_pooled = 8, 4, 8
|
||||
sequence_length = 8
|
||||
@unittest.skip("HiDreamImageTransformer2DModel uses a dedicated attention processor. This test doesn't apply")
|
||||
def test_set_attn_processor_for_determinism(self):
|
||||
pass
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states_t5": randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim_t5), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states_llama3": randn_tensor(
|
||||
(batch_size, batch_size, sequence_length, embedding_dim_llama),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"pooled_embeds": randn_tensor(
|
||||
(batch_size, embedding_dim_pooled), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timesteps": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestHiDreamTransformer(HiDreamTransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestHiDreamTransformerTraining(HiDreamTransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"HiDreamImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestHiDreamTransformerCompile(HiDreamTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import LongCatImageTransformer2DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LongCatImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return LongCatImageTransformer2DModel
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
"num_single_layers": 1,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 2,
|
||||
"joint_attention_dim": 32,
|
||||
"pooled_projection_dim": 32,
|
||||
"axes_dims_rope": [4, 4, 8],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
height = width = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"img_ids": randn_tensor(
|
||||
(height * width, num_image_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"txt_ids": randn_tensor(
|
||||
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
"guidance": torch.tensor([3.5]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
|
||||
|
||||
class TestLongCatImageTransformer(LongCatImageTransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestLongCatImageTransformerTraining(LongCatImageTransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"LongCatImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestLongCatImageTransformerCompile(LongCatImageTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
Reference in New Issue
Block a user