Mirror of https://github.com/huggingface/diffusers.git (synced 2026-03-27 19:07:39 +08:00)

Compare commits: `bria-test-…` → `profiling-…` (16 commits)
- a410b4958c
- bfbaf079cd
- bf5131fba9
- 6a23a771aa
- 96506c85d0
- 179fa51342
- 60d4148529
- b2b6330a54
- e4d6293b4d
- eddef12a54
- af96109435
- b757035df6
- 41e1003316
- 85ffcf1db2
- cbf4d9a3c3
- 426daabad9
11  .ai/review-rules.md  Normal file

@@ -0,0 +1,11 @@
# PR Review Rules

Review-specific rules for Claude. Focus on correctness — style is handled by ruff.

Before reviewing, read and apply the guidelines in:

- [AGENTS.md](AGENTS.md) — coding style, dependencies, copied code, model conventions
- [skills/model-integration/SKILL.md](skills/model-integration/SKILL.md) — attention pattern, pipeline rules, implementation checklist, gotchas
- [skills/parity-testing/SKILL.md](skills/parity-testing/SKILL.md) — testing rules, comparison utilities
- [skills/parity-testing/pitfalls.md](skills/parity-testing/pitfalls.md) — known pitfalls (dtype mismatches, config assumptions, etc.)

## Common mistakes (add new rules below this line)
39  .github/workflows/claude_review.yml  vendored  Normal file

@@ -0,0 +1,39 @@
name: Claude PR Review

on:
  issue_comment:
    types: [created]
  pull_request_review_comment:
    types: [created]

permissions:
  contents: write
  pull-requests: write
  issues: read
  id-token: write

jobs:
  claude-review:
    if: |
      (
        github.event_name == 'issue_comment' &&
        github.event.issue.pull_request &&
        github.event.issue.state == 'open' &&
        contains(github.event.comment.body, '@claude') &&
        (github.event.comment.author_association == 'MEMBER' ||
        github.event.comment.author_association == 'OWNER' ||
        github.event.comment.author_association == 'COLLABORATOR')
      ) || (
        github.event_name == 'pull_request_review_comment' &&
        contains(github.event.comment.body, '@claude') &&
        (github.event.comment.author_association == 'MEMBER' ||
        github.event.comment.author_association == 'OWNER' ||
        github.event.comment.author_association == 'COLLABORATOR')
      )
    runs-on: ubuntu-latest
    steps:
      - uses: anthropics/claude-code-action@v1
        with:
          anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
          claude_args: |
            --append-system-prompt "Review this PR against the rules in .ai/review-rules.md. Focus on correctness, not style (ruff handles style). Only review changes under src/diffusers/. Do NOT commit changes unless the comment explicitly asks you to using the phrase 'commit this'."
@@ -248,6 +248,24 @@ Refer to the [diffusers/benchmarks](https://huggingface.co/datasets/diffusers/be

The [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao#benchmarking-results) repository also contains benchmarking results for compiled versions of Flux and CogVideoX.

## Kernels

[Kernels](https://huggingface.co/docs/kernels/index) is a library for building, distributing, and loading optimized compute kernels on the [Hub](https://huggingface.co/kernels-community). It supports [attention](./attention_backends#set_attention_backend) kernels and custom CUDA kernels for operations like RMSNorm, GEGLU, RoPE, and AdaLN.
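For instance, an attention kernel from the Hub can be enabled with the `set_attention_backend` helper linked above. A minimal sketch — the checkpoint and the `"flash"` backend name are assumptions; check the attention backends guide for the names available in your install:

```py
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Swap the transformer's attention implementation for a kernel-backed one.
pipe.transformer.set_attention_backend("flash")

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
```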
The [Diffusers Pipeline Integration](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/references/diffusers-integration.md) guide shows how to integrate a kernel with the [add cuda-kernels](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/SKILL.md) skill. This skill enables an agent, like Claude or Codex, to write custom kernels targeted towards a specific model and your hardware.

> [!TIP]
> Install the [add cuda-kernels](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/SKILL.md) skill to teach an agent how to write a kernel. The [Custom kernels for all from Codex and Claude](https://huggingface.co/blog/custom-cuda-kernels-agent-skills) blog post covers this in more detail.

For example, a custom RMSNorm kernel (generated by the `add cuda-kernels` skill) with [torch.compile](#torchcompile) speeds up LTX-Video generation 1.43x on an H100.

<iframe
  src="https://huggingface.co/datasets/docs-benchmarks/kernel-ltx-video/embed/viewer/default/train"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>

## Dynamic quantization

[Dynamic quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) improves inference speed by reducing precision to enable faster math operations. This particular type of quantization determines how to scale the activations based on the data at runtime rather than using a fixed scaling factor. As a result, the scaling factor is more accurately aligned with the data.
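A minimal sketch of applying it with torchao — `quantize_` and `float8_dynamic_activation_float8_weight` reflect torchao's quantization API at the time of writing and should be checked against your installed version; float8 also assumes recent (e.g. Hopper-class) hardware:

```py
import torch
from diffusers import FluxPipeline
from torchao.quantization import float8_dynamic_activation_float8_weight, quantize_

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Weights are stored in float8; activation scales are computed from the data
# at runtime (dynamic) rather than fixed ahead of time.
quantize_(pipe.transformer, float8_dynamic_activation_float8_weight())
```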
@@ -1105,7 +1105,7 @@ def main(args):

        # text encoding.
        captions = batch["captions"]
-       text_encoding_pipeline = text_encoding_pipeline.to("cuda")
+       text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
        with torch.no_grad():
            prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
                captions, prompt_2=None

@@ -1251,7 +1251,7 @@ def main(args):

        # text encoding.
        captions = batch["captions"]
-       text_encoding_pipeline = text_encoding_pipeline.to("cuda")
+       text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
        with torch.no_grad():
            prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
                captions, prompt_2=None
250  examples/profiling/README.md  Normal file

@@ -0,0 +1,250 @@
# Profiling Plan: Diffusers Pipeline Profiling with torch.profiler

Educational materials for strategically profiling pipelines to improve their runtime with `torch.compile`. To set these pipelines up for success with `torch.compile`, we often have to get rid of DtoH syncs, CPU overheads, kernel launch delays, and graph breaks. Profiling is how we find them.

Thanks to Claude Code for pair-coding! We acknowledge the [Claude for OSS](https://claude.com/contact-sales/claude-for-oss) support provided to us.

## Table of contents

* [Context](#context)
* [Target pipelines](#target-pipelines)
* [Approach taken](#approach)
* [Verification](#verification)
* [Interpretation](#interpreting-traces-in-perfetto-ui)
* [Taking profiling-guided steps for improvements](#afterwards)

Jump to the [Verification](#verification) section to get started right away.

## Context

We want to uncover CPU overhead, CPU-GPU sync points, and other bottlenecks in popular diffusers pipelines — especially issues that become non-trivial under `torch.compile`. The approach is inspired by [flux-fast's run_benchmark.py](https://github.com/huggingface/flux-fast/blob/0a1dcc91658f0df14cd7fce862a5c8842784c6da/run_benchmark.py#L66-L85), which uses `torch.profiler` with method-level annotations, and motivated by issues like [diffusers#11696](https://github.com/huggingface/diffusers/pull/11696) (DtoH sync from a scheduler `.item()` call).
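As a standalone illustration (not taken from any pipeline), the two classic DtoH sync patterns look like this:

```python
import torch

x = torch.randn(4096, device="cuda")
y = x * 2  # enqueued asynchronously; the CPU moves on immediately

# Data-dependent results force the CPU to wait for the GPU:
idx = (y > 0).nonzero()  # output shape depends on values -> implicit sync
val = y[0].item()        # scalar readback -> explicit DtoH sync
```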
## Target Pipelines

| Pipeline | Type | Checkpoint | Steps |
|----------|------|------------|-------|
| `FluxPipeline` | text-to-image | `black-forest-labs/FLUX.1-dev` | 2 |
| `Flux2KleinPipeline` | text-to-image | `black-forest-labs/FLUX.2-klein-base-9B` | 2 |
| `WanPipeline` | text-to-video | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | 2 |
| `LTX2Pipeline` | text-to-video | `Lightricks/LTX-2` | 2 |
| `QwenImagePipeline` | text-to-image | `Qwen/Qwen-Image` | 2 |

> [!NOTE]
> We use realistic inference call hyperparameters that mimic how these pipelines will actually be used. This
> includes classifier-free guidance (where applicable), reasonable dimensions such as 1024x1024, etc.
> But we keep the overall running time to a bare minimum (hence 2 `num_inference_steps`).
## Approach

Follow the flux-fast pattern: **annotate key pipeline methods** with `torch.profiler.record_function` wrappers, then run the pipeline under `torch.profiler.profile` and export a Chrome trace.

### New Files

```
profiling/
  profiling_utils.py      # Annotation helper + profiler setup
  profiling_pipelines.py  # CLI entry point with pipeline configs
```

### Step 1: `profiling_utils.py` — Annotation and Profiler Infrastructure

**A) `annotate(func, name)` helper** (same pattern as flux-fast):

```python
def annotate(func, name):
    """Wrap a function with torch.profiler.record_function for trace annotation."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with torch.profiler.record_function(name):
            return func(*args, **kwargs)

    return wrapper
```
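Usage is one line per method — a sketch of what `annotate_pipeline` below does internally:

```python
pipe.transformer.forward = annotate(pipe.transformer.forward, "transformer_forward")
pipe.scheduler.step = annotate(pipe.scheduler.step, "scheduler_step")
```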
**B) `annotate_pipeline(pipe)` function** — applies annotations to key methods on any pipeline:

- `pipe.transformer.forward` → `"transformer_forward"`
- `pipe.vae.decode` → `"vae_decode"` (if present)
- `pipe.vae.encode` → `"vae_encode"` (if present)
- `pipe.scheduler.step` → `"scheduler_step"`
- `pipe.encode_prompt` → `"encode_prompt"` (if present, for full-pipeline profiling)

This is non-invasive — it monkey-patches bound methods without modifying source.
**C) `PipelineProfiler` class:**

- `__init__(pipeline_config, output_dir, mode="eager"|"compile")`
- `setup_pipeline()` → loads from pretrained, optionally compiles transformer, calls `annotate_pipeline()`
- `run()`:
  1. Warm up with 1 unannotated run
  2. Profile 1 run with `torch.profiler.profile`:
     - `activities=[CPU, CUDA]`
     - `record_shapes=True`
     - `profile_memory=True`
     - `with_stack=True`
  3. Export Chrome trace JSON
  4. Print `key_averages()` summary table (sorted by CUDA time) to stdout

### Step 2: `profiling_pipelines.py` — CLI with Pipeline Configs

**Pipeline config registry** — each entry specifies:

- `pipeline_cls`, `pretrained_model_name_or_path`, `torch_dtype`
- `call_kwargs` with pipeline-specific defaults:

| Pipeline | Resolution | Frames | Steps | Extra |
|----------|-----------|--------|-------|-------|
| Flux | 1024x1024 | — | 2 | `guidance_scale=3.5` |
| Flux2Klein | 1024x1024 | — | 2 | `guidance_scale=3.5` |
| Wan | 480x832 | 81 | 2 | — |
| LTX2 | 768x512 | 121 | 2 | `guidance_scale=4.0` |
| QwenImage | 1024x1024 | — | 2 | `true_cfg_scale=4.0` |

All configs use `output_type="latent"` by default (skip VAE decode for cleaner denoising-loop traces).
**CLI flags:**

- `--pipeline flux|flux2|wan|ltx2|qwenimage|all`
- `--mode eager|compile|both`
- `--output_dir profiling_results/`
- `--num_steps N` (override, default 4)
- `--full_decode` (switch output_type from `"latent"` to `"pil"` to include VAE)
- `--compile_mode default|reduce-overhead|max-autotune`
- `--compile_fullgraph` flag

**Output:** `{output_dir}/{pipeline}_{mode}.json` Chrome trace + stdout summary.

### Step 3: Known Sync Issues to Validate

The profiling should surface these known/suspected issues:

1. **Scheduler DtoH sync via `nonzero().item()`** — For Flux, this was fixed by adding `scheduler.set_begin_index(0)` before the denoising loop ([diffusers#11696](https://github.com/huggingface/diffusers/pull/11696)). Profiling should reveal whether similar sync points exist in other pipelines.

2. **`modulate_index` tensor rebuilt every forward in `transformer_qwenimage.py`** (lines 901-905) — Python list comprehension + `torch.tensor()` each step. Minor but visible in the trace.

3. **Any other `.item()`, `.cpu()`, `.numpy()` calls** in the denoising loop hot path — the profiler's `with_stack=True` will surface these as CPU stalls with Python stack traces. A complementary way to catch them at runtime is sketched below.
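A minimal sketch of that runtime check, using PyTorch's built-in sync debug mode (independent of the profiler):

```python
import torch

# Make every implicit CPU-GPU sync loud: "warn" emits a warning (with a
# stack trace) on each synchronizing call, "error" raises instead.
torch.cuda.set_sync_debug_mode("warn")

# `pipe` is assumed to be a pipeline loaded as in the verification steps below.
pipe("A cat holding a sign that says hello world", num_inference_steps=2, output_type="latent")

torch.cuda.set_sync_debug_mode("default")  # restore
```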
## Verification

1. Run: `python profiling/profiling_pipelines.py --pipeline flux --mode eager --num_steps 2`
2. Verify `profiling_results/flux_eager.json` is produced
3. Open the trace in [Perfetto UI](https://ui.perfetto.dev/) — confirm:
   - `transformer_forward` and `scheduler_step` annotations visible
   - CPU and CUDA timelines present
   - Stack traces visible on CPU events
4. Run with `--mode compile` and compare the trace for fewer/fused CUDA kernels

You can also use the `run_profiling.sh` script to bulk-launch runs for different pipelines.
## Interpreting Traces in Perfetto UI

Open the exported `.json` trace at [ui.perfetto.dev](https://ui.perfetto.dev/). The trace has two main rows: **CPU** (top) and **CUDA** (bottom). In Perfetto, the CPU row is typically labeled with the process/thread name (e.g., `python (PID)` or `MainThread`) and appears at the top. The CUDA row is labeled `GPU 0` (or similar) and appears below the CPU rows.

**Navigation:** Use `W` to zoom in, `S` to zoom out, and `A`/`D` to pan left/right. You can also scroll to zoom and click-drag to pan. Use `Shift+scroll` to scroll vertically through rows.

### What to look for

**1. Gaps between CUDA kernels**

Zoom into the CUDA row during the denoising loop. Ideally, GPU kernels should be back-to-back with no gaps. Gaps mean the GPU is idle waiting for the CPU to launch the next kernel. Common causes:

- Python overhead between ops (visible as CPU slices in the CPU row during the gap)
- DtoH sync (`.item()`, `.cpu()`) forcing the GPU to drain before the CPU can proceed

**2. CPU stalls (DtoH syncs)**

These appear on the **CPU row** (not the CUDA row) — they are CPU-side blocking calls that wait for the GPU to finish. Look for long slices labeled `cudaStreamSynchronize` or `cudaDeviceSynchronize`. To find them: zoom into the CPU row during a denoising step and look for unusually wide slices, or use Perfetto's search bar (press `/`) and type `cudaStreamSynchronize` to jump directly to matching events. Click on a slice — if `with_stack=True` was enabled, the bottom panel ("Current Selection") shows the Python stack trace pointing to the exact line causing the sync (e.g., a `.item()` call in the scheduler).

**3. Annotated regions**

Our `record_function` annotations (`transformer_forward`, `scheduler_step`, etc.) appear as labeled spans on the CPU row. This lets you quickly:

- Measure how long each phase takes (click a span to see its duration)
- See if `scheduler_step` is disproportionately expensive relative to `transformer_forward` (it should be negligible)
- Spot unexpected CPU work between annotated regions

**4. Eager vs compile comparison**

Open both traces side by side (two Perfetto tabs). Key differences to look for:

- **Fewer, wider CUDA kernels** in compile mode (fused ops) vs many small kernels in eager
- **Smaller CPU gaps** between kernels in compile mode (less Python dispatch overhead)
- **CUDA kernel count per step**: to compare, zoom into a single `transformer_forward` span on the CUDA row and count the distinct kernel slices within it. In eager mode you'll typically see many narrow slices (one per op); in compile mode these fuse into fewer, wider slices. A quick way to estimate: select a time range covering one denoising step on the CUDA row — Perfetto shows the number of slices in the selection summary at the bottom. If compile mode shows a similar kernel count to eager, fusion isn't happening effectively (likely due to graph breaks).
- **Graph breaks**: if compile mode still shows many small kernels in a section, that section likely has a graph break — check `TORCH_LOGS="+dynamo"` output for details, or count breaks directly as sketched after this list
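To count breaks outside the trace, `torch._dynamo.explain` works on any function — a toy sketch (apply the same call to a model's forward with real example inputs; attribute names follow recent PyTorch releases):

```python
import torch

def f(x):
    y = x.sin()
    print("side effect")  # graph break: dynamo can't trace arbitrary Python I/O
    return y.cos()

explanation = torch._dynamo.explain(f)(torch.randn(8))
print(explanation.graph_break_count)  # 1
print(explanation.break_reasons)
```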
**5. Memory timeline**

In Perfetto, look for the memory counter track (if `profile_memory=True`). Spikes during the denoising loop suggest unexpected allocations per step. Steady-state memory during denoising is expected — growing memory is not.
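A coarse cross-check outside Perfetto (a sketch assuming a loaded `pipe`; step counts are arbitrary):

```python
import torch

prompt = "A cat holding a sign that says hello world"

torch.cuda.reset_peak_memory_stats()
pipe(prompt, num_inference_steps=2, output_type="latent")
peak_short = torch.cuda.max_memory_allocated()

torch.cuda.reset_peak_memory_stats()
pipe(prompt, num_inference_steps=8, output_type="latent")
peak_long = torch.cuda.max_memory_allocated()

# Peak memory should be roughly flat in the number of steps; a large gap
# suggests per-step allocations accumulating across the loop.
print(f"{peak_short / 1e9:.2f} GB vs {peak_long / 1e9:.2f} GB")
```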
**6. Kernel launch latency**

Each CUDA kernel is launched from the CPU. The CPU-side launch calls (`cudaLaunchKernel`) appear as small slices on the **CPU row** — zoom in closely to a denoising step to see them. The corresponding GPU-side kernel executions appear on the **CUDA row** directly below. You can also use Perfetto's search bar (`/`) and type `cudaLaunchKernel` to find them. The time between the CPU dispatch and the GPU kernel starting should be minimal (single-digit microseconds). If you see consistent delays > 10-20us between launch and execution:

- The launch queue may be starved because of excessive Python work between ops
- There may be implicit syncs forcing serialization
- `torch.compile` should help here by batching launches — compare eager vs compile to confirm

To inspect this: zoom into a single denoising step, select a CUDA kernel on the GPU row, and look at the corresponding CPU-side launch slice directly above it. The horizontal offset between them is the launch latency. In a healthy trace, CPU launch slices should be well ahead of GPU execution (the CPU is "feeding" the GPU faster than it can consume).
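When launch overhead dominates, CUDA graphs can batch launches. A hedged sketch using the `reduce-overhead` compile mode that the CLI above already exposes (assumes a loaded `pipe`):

```python
# CUDA graphs capture the kernel sequence once and replay it with a single
# launch, hiding per-kernel dispatch cost. fullgraph=True surfaces graph
# breaks as hard errors instead of silent fallbacks.
pipe.transformer.compile(mode="reduce-overhead", fullgraph=True)
```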
### Quick checklist per pipeline

| Question | Where to look | Healthy | Unhealthy |
|----------|--------------|---------|-----------|
| GPU staying busy? | CUDA row gaps | Back-to-back kernels | Frequent gaps > 100us |
| CPU blocking on GPU? | `cudaStreamSynchronize` slices | Rare/absent during denoise | Present every step |
| Scheduler overhead? | `scheduler_step` span duration | < 1% of step time | > 5% of step time |
| Compile effective? | CUDA kernel count per step | Fewer large kernels | Same as eager |
| Kernel launch latency? | CPU launch → GPU kernel offset | < 10us, CPU ahead of GPU | > 20us or CPU trailing GPU |
| Memory stable? | Memory counter track | Flat during denoise loop | Growing per step |
## Afterwards

To keep the profiling iterations fast, we always used [regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html). As one would expect, the trace with compilation should show fewer kernel launches than its eager counterpart:

TODO: show traces

_(The traces above were obtained with Flux2.)_
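For reference, a minimal sketch of how regional compilation is switched on in this setup (mirroring what `profiling_utils.py` does when `--compile_regional` is passed):

```python
# Compile only the repeated transformer blocks instead of the whole model:
# near-identical steady-state speed, dramatically lower compile time.
pipe.transformer.compile_repeated_blocks(fullgraph=True)
```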
### Spotting gaps between launches

A reasonable next step is to spot frequent gaps between kernel executions. In the compiled case, we don't spot any on the surface. But if we zoom in, some become apparent.

TODO: show gaps in a compile trace

So, we provided the profile trace (with compilation) to Claude, asked it to find the instances of "cudaStreamSynchronize" and "cudaDeviceSynchronize", and to come up with some potential fixes. Claude came back pretty strong:
```
Issue 1 — Gap between transformer forwards:
- Root cause: tqdm progress bar update() calls between steps add CPU overhead (I/O, time calculations)
- Fix: profiling/profiling_utils.py — added pipe.set_progress_bar_config(disable=True) during profiling setup.
  This eliminates the tqdm overhead from the trace. (The remaining gap from scheduler step + Python dispatch is
  inherent to eager-mode execution and should shrink significantly under torch.compile.)

Issue 2 — cudaStreamSynchronize during last transformer forward:
- Root cause: _unpack_latents_with_ids() (called right after the denoising loop) computes h = torch.max(h_ids) + 1
  and w = torch.max(w_ids) + 1 on GPU tensors, then uses them as shape args for torch.zeros((h * w, ch), ...).
  This triggers an implicit .item() DtoH sync, blocking the CPU while the GPU is still finishing the last
  transformer forward's kernels.
- Fix: Added height/width parameters to _unpack_latents_with_ids(), pre-computed from the known pixel dimensions
  at the call site.
```

It still didn't eliminate the gaps as expected, so we fed that back to Claude and it spotted something more crucial. TODO: caching context fix.

With the fix applied, the improvements were visible:

TODO: show before and after trace

Before:

- `_set_context` total: 21.6ms (8 calls)
- cache_context total: 21.7ms
- CPU gaps: 5,523us / 8,007us / 5,508us

After:

- `_set_context` total: 0.0ms (8 calls)
- cache_context total: 0.1ms
- CPU gaps: 158us / 2,777us / 136us
182  examples/profiling/profiling_pipelines.py  Normal file

@@ -0,0 +1,182 @@
"""
Profile diffusers pipelines with torch.profiler.

Usage:
    python profiling/profiling_pipelines.py --pipeline flux --mode eager
    python profiling/profiling_pipelines.py --pipeline flux --mode compile
    python profiling/profiling_pipelines.py --pipeline flux --mode both
    python profiling/profiling_pipelines.py --pipeline all --mode eager
    python profiling/profiling_pipelines.py --pipeline wan --mode eager --full_decode
    python profiling/profiling_pipelines.py --pipeline flux --mode compile --num_steps 4
"""

import argparse
import copy
import logging

import torch

from profiling_utils import PipelineProfiler, PipelineProfilingConfig


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)

PROMPT = "A cat holding a sign that says hello world"


def build_registry():
    """Build the pipeline config registry. Imports are deferred to avoid loading all pipelines upfront."""
    from diffusers import Flux2KleinPipeline, FluxPipeline, LTX2Pipeline, QwenImagePipeline, WanPipeline

    return {
        "flux": PipelineProfilingConfig(
            name="flux",
            pipeline_cls=FluxPipeline,
            pipeline_init_kwargs={
                "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev",
                "torch_dtype": torch.bfloat16,
            },
            pipeline_call_kwargs={
                "prompt": PROMPT,
                "height": 1024,
                "width": 1024,
                "num_inference_steps": 4,
                "guidance_scale": 3.5,
                "output_type": "latent",
            },
        ),
        "flux2": PipelineProfilingConfig(
            name="flux2",
            pipeline_cls=Flux2KleinPipeline,
            pipeline_init_kwargs={
                "pretrained_model_name_or_path": "black-forest-labs/FLUX.2-klein-base-9B",
                "torch_dtype": torch.bfloat16,
            },
            pipeline_call_kwargs={
                "prompt": PROMPT,
                "height": 1024,
                "width": 1024,
                "num_inference_steps": 4,
                "guidance_scale": 3.5,
                "output_type": "latent",
            },
        ),
        "wan": PipelineProfilingConfig(
            name="wan",
            pipeline_cls=WanPipeline,
            pipeline_init_kwargs={
                "pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
                "torch_dtype": torch.bfloat16,
            },
            pipeline_call_kwargs={
                "prompt": PROMPT,
                "negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
                "height": 480,
                "width": 832,
                "num_frames": 81,
                "num_inference_steps": 4,
                "output_type": "latent",
            },
        ),
        "ltx2": PipelineProfilingConfig(
            name="ltx2",
            pipeline_cls=LTX2Pipeline,
            pipeline_init_kwargs={
                "pretrained_model_name_or_path": "Lightricks/LTX-2",
                "torch_dtype": torch.bfloat16,
            },
            pipeline_call_kwargs={
                "prompt": PROMPT,
                "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
                "height": 512,
                "width": 768,
                "num_frames": 121,
                "num_inference_steps": 4,
                "guidance_scale": 4.0,
                "output_type": "latent",
            },
        ),
        "qwenimage": PipelineProfilingConfig(
            name="qwenimage",
            pipeline_cls=QwenImagePipeline,
            pipeline_init_kwargs={
                "pretrained_model_name_or_path": "Qwen/Qwen-Image",
                "torch_dtype": torch.bfloat16,
            },
            pipeline_call_kwargs={
                "prompt": PROMPT,
                "negative_prompt": " ",
                "height": 1024,
                "width": 1024,
                "num_inference_steps": 4,
                "true_cfg_scale": 4.0,
                "output_type": "latent",
            },
        ),
    }


def main():
    parser = argparse.ArgumentParser(description="Profile diffusers pipelines with torch.profiler")
    parser.add_argument(
        "--pipeline",
        choices=["flux", "flux2", "wan", "ltx2", "qwenimage", "all"],
        required=True,
        help="Which pipeline to profile",
    )
    parser.add_argument(
        "--mode",
        choices=["eager", "compile", "both"],
        default="eager",
        help="Run in eager mode, compile mode, or both",
    )
    parser.add_argument("--output_dir", default="profiling_results", help="Directory for trace output")
    parser.add_argument("--num_steps", type=int, default=None, help="Override num_inference_steps")
    parser.add_argument("--full_decode", action="store_true", help="Profile including VAE decode (output_type='pil')")
    parser.add_argument(
        "--compile_mode",
        default="default",
        choices=["default", "reduce-overhead", "max-autotune"],
        help="torch.compile mode",
    )
    parser.add_argument("--compile_fullgraph", action="store_true", help="Use fullgraph=True for torch.compile")
    parser.add_argument(
        "--compile_regional",
        action="store_true",
        help="Use compile_repeated_blocks() instead of full model compile",
    )
    args = parser.parse_args()

    registry = build_registry()

    pipeline_names = list(registry.keys()) if args.pipeline == "all" else [args.pipeline]
    modes = ["eager", "compile"] if args.mode == "both" else [args.mode]

    for pipeline_name in pipeline_names:
        for mode in modes:
            config = copy.deepcopy(registry[pipeline_name])

            # Apply overrides
            if args.num_steps is not None:
                config.pipeline_call_kwargs["num_inference_steps"] = args.num_steps
            if args.full_decode:
                config.pipeline_call_kwargs["output_type"] = "pil"
            if mode == "compile":
                config.compile_kwargs = {
                    "fullgraph": args.compile_fullgraph,
                    "mode": args.compile_mode,
                }
                config.compile_regional = args.compile_regional

            logger.info(f"Profiling {pipeline_name} in {mode} mode...")
            profiler = PipelineProfiler(config, args.output_dir)
            try:
                trace_file = profiler.run()
                logger.info(f"Done: {trace_file}")
            except Exception as e:
                logger.error(f"Failed to profile {pipeline_name} ({mode}): {e}")


if __name__ == "__main__":
    main()
146  examples/profiling/profiling_utils.py  Normal file

@@ -0,0 +1,146 @@
import functools
import gc
import logging
import os
from dataclasses import dataclass, field
from typing import Any

import torch
import torch.profiler


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)


def annotate(func, name):
    """Wrap a function with torch.profiler.record_function for trace annotation."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with torch.profiler.record_function(name):
            return func(*args, **kwargs)

    return wrapper


def annotate_pipeline(pipe):
    """Apply profiler annotations to key pipeline methods.

    Monkey-patches bound methods so they appear as named spans in the trace.
    Non-invasive — no source modifications required.
    """
    annotations = [
        ("transformer", "forward", "transformer_forward"),
        ("vae", "decode", "vae_decode"),
        ("vae", "encode", "vae_encode"),
        ("scheduler", "step", "scheduler_step"),
    ]

    # Annotate sub-component methods
    for component_name, method_name, label in annotations:
        component = getattr(pipe, component_name, None)
        if component is None:
            continue
        method = getattr(component, method_name, None)
        if method is None:
            continue
        setattr(component, method_name, annotate(method, label))

    # Annotate pipeline-level methods
    if hasattr(pipe, "encode_prompt"):
        pipe.encode_prompt = annotate(pipe.encode_prompt, "encode_prompt")


def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()


@dataclass
class PipelineProfilingConfig:
    name: str
    pipeline_cls: Any
    pipeline_init_kwargs: dict[str, Any]
    pipeline_call_kwargs: dict[str, Any]
    compile_kwargs: dict[str, Any] | None = field(default=None)
    compile_regional: bool = False


class PipelineProfiler:
    def __init__(self, config: PipelineProfilingConfig, output_dir: str = "profiling_results"):
        self.config = config
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

    def setup_pipeline(self):
        """Load the pipeline from pretrained, optionally compile, and annotate."""
        logger.info(f"Loading pipeline: {self.config.name}")
        pipe = self.config.pipeline_cls.from_pretrained(**self.config.pipeline_init_kwargs)
        pipe.to("cuda")

        if self.config.compile_kwargs:
            if self.config.compile_regional:
                logger.info(f"Regional compilation (compile_repeated_blocks) with kwargs: {self.config.compile_kwargs}")
                pipe.transformer.compile_repeated_blocks(**self.config.compile_kwargs)
            else:
                logger.info(f"Full compilation with kwargs: {self.config.compile_kwargs}")
                pipe.transformer.compile(**self.config.compile_kwargs)

        # Disable tqdm progress bar to avoid CPU overhead / IO between steps
        pipe.set_progress_bar_config(disable=True)

        annotate_pipeline(pipe)
        return pipe

    def run(self):
        """Execute the profiling run: warmup, then profile one pipeline call."""
        pipe = self.setup_pipeline()
        flush()

        mode = "compile" if self.config.compile_kwargs else "eager"
        trace_file = os.path.join(self.output_dir, f"{self.config.name}_{mode}.json")

        # Warmup (pipeline __call__ is already decorated with @torch.no_grad())
        logger.info("Running warmup...")
        pipe(**self.config.pipeline_call_kwargs)
        flush()

        # Profile
        logger.info("Running profiled iteration...")
        activities = [
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ]
        with torch.profiler.profile(
            activities=activities,
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        ) as prof:
            with torch.profiler.record_function("pipeline_call"):
                pipe(**self.config.pipeline_call_kwargs)

        # Export trace
        prof.export_chrome_trace(trace_file)
        logger.info(f"Chrome trace saved to: {trace_file}")

        # Print summary
        print("\n" + "=" * 80)
        print(f"Profile summary: {self.config.name} ({mode})")
        print("=" * 80)
        print(
            prof.key_averages().table(
                sort_by="cuda_time_total",
                row_limit=20,
            )
        )

        # Cleanup
        pipe.to("cpu")
        del pipe
        flush()

        return trace_file
46  examples/profiling/run_profiling.sh  Executable file

@@ -0,0 +1,46 @@
#!/bin/bash
# Run profiling across all pipelines in eager and compile (regional) modes.
#
# Usage:
#   bash profiling/run_profiling.sh
#   bash profiling/run_profiling.sh --output_dir my_results

set -euo pipefail

OUTPUT_DIR="profiling_results"
while [[ $# -gt 0 ]]; do
  case "$1" in
    --output_dir) OUTPUT_DIR="$2"; shift 2 ;;
    *) echo "Unknown arg: $1"; exit 1 ;;
  esac
done

NUM_STEPS=2
# PIPELINES=("flux" "flux2" "wan" "ltx2" "qwenimage")
PIPELINES=("flux2")
MODES=("eager" "compile")

for pipeline in "${PIPELINES[@]}"; do
  for mode in "${MODES[@]}"; do
    echo "============================================================"
    echo "Profiling: ${pipeline} | mode: ${mode}"
    echo "============================================================"

    COMPILE_ARGS=""
    if [ "$mode" = "compile" ]; then
      COMPILE_ARGS="--compile_regional --compile_fullgraph --compile_mode default"
    fi

    python profiling/profiling_pipelines.py \
      --pipeline "$pipeline" \
      --mode "$mode" \
      --output_dir "$OUTPUT_DIR" \
      --num_steps "$NUM_STEPS" \
      $COMPILE_ARGS

    echo ""
  done
done

echo "============================================================"
echo "All traces saved to: ${OUTPUT_DIR}/"
echo "============================================================"
@@ -271,12 +271,31 @@ class HookRegistry:
            if hook._is_stateful:
                hook._set_context(self._module_ref, name)

        for registry in self._get_child_registries():
            registry._set_context(name)

    def _get_child_registries(self) -> list["HookRegistry"]:
        """Return registries of child modules, using a cached list when available.

        The cache is built on first call and reused for subsequent calls. This avoids the cost of walking the full
        module tree via named_modules() on every _set_context call, which is significant for large models (e.g. ~2.7ms
        per call on Flux2).
        """
        if not hasattr(self, "_child_registries_cache"):
            self._child_registries_cache = None

        if self._child_registries_cache is not None:
            return self._child_registries_cache

        registries = []
        for module_name, module in unwrap_module(self._module_ref).named_modules():
            if module_name == "":
                continue
            module = unwrap_module(module)
            if hasattr(module, "_diffusers_hook"):
                registries.append(module._diffusers_hook)
        self._child_registries_cache = registries
        return registries

    def __repr__(self) -> str:
        registry_repr = ""
@@ -397,7 +397,9 @@ class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):

    @staticmethod
    # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
-   def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
+   def _unpack_latents_with_ids(
+       x: torch.Tensor, x_ids: torch.Tensor, height: int | None = None, width: int | None = None
+   ) -> list[torch.Tensor]:
        """
        using position ids to scatter tokens into place
        """

@@ -407,8 +409,9 @@ class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
        h_ids = pos[:, 1].to(torch.int64)
        w_ids = pos[:, 2].to(torch.int64)

-       h = torch.max(h_ids) + 1
-       w = torch.max(w_ids) + 1
+       # Use provided height/width to avoid DtoH sync from torch.max().item()
+       h = height if height is not None else torch.max(h_ids) + 1
+       w = width if width is not None else torch.max(w_ids) + 1

        flat_ids = h_ids * w + w_ids

@@ -895,7 +898,10 @@ class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):

        self._current_timestep = None

-       latents = self._unpack_latents_with_ids(latents, latent_ids)
+       # Pass pre-computed latent height/width to avoid DtoH sync from torch.max().item()
+       latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
+       latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
+       latents = self._unpack_latents_with_ids(latents, latent_ids, latent_height // 2, latent_width // 2)

        latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
        latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
242  tests/modular_pipelines/test_conditional_pipeline_blocks.py  Normal file

@@ -0,0 +1,242 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from diffusers.modular_pipelines import (
    AutoPipelineBlocks,
    ConditionalPipelineBlocks,
    InputParam,
    ModularPipelineBlocks,
)


class TextToImageBlock(ModularPipelineBlocks):
    model_name = "text2img"

    @property
    def inputs(self):
        return [InputParam(name="prompt")]

    @property
    def intermediate_outputs(self):
        return []

    @property
    def description(self):
        return "text-to-image workflow"

    def __call__(self, components, state):
        block_state = self.get_block_state(state)
        block_state.workflow = "text2img"
        self.set_block_state(state, block_state)
        return components, state


class ImageToImageBlock(ModularPipelineBlocks):
    model_name = "img2img"

    @property
    def inputs(self):
        return [InputParam(name="prompt"), InputParam(name="image")]

    @property
    def intermediate_outputs(self):
        return []

    @property
    def description(self):
        return "image-to-image workflow"

    def __call__(self, components, state):
        block_state = self.get_block_state(state)
        block_state.workflow = "img2img"
        self.set_block_state(state, block_state)
        return components, state


class InpaintBlock(ModularPipelineBlocks):
    model_name = "inpaint"

    @property
    def inputs(self):
        return [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]

    @property
    def intermediate_outputs(self):
        return []

    @property
    def description(self):
        return "inpaint workflow"

    def __call__(self, components, state):
        block_state = self.get_block_state(state)
        block_state.workflow = "inpaint"
        self.set_block_state(state, block_state)
        return components, state


class ConditionalImageBlocks(ConditionalPipelineBlocks):
    block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
    block_names = ["inpaint", "img2img", "text2img"]
    block_trigger_inputs = ["mask", "image"]
    default_block_name = "text2img"

    @property
    def description(self):
        return "Conditional image blocks for testing"

    def select_block(self, mask=None, image=None) -> str | None:
        if mask is not None:
            return "inpaint"
        if image is not None:
            return "img2img"
        return None  # falls back to default_block_name


class OptionalConditionalBlocks(ConditionalPipelineBlocks):
    block_classes = [InpaintBlock, ImageToImageBlock]
    block_names = ["inpaint", "img2img"]
    block_trigger_inputs = ["mask", "image"]
    default_block_name = None  # no default; block can be skipped

    @property
    def description(self):
        return "Optional conditional blocks (skippable)"

    def select_block(self, mask=None, image=None) -> str | None:
        if mask is not None:
            return "inpaint"
        if image is not None:
            return "img2img"
        return None


class AutoImageBlocks(AutoPipelineBlocks):
    block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
    block_names = ["inpaint", "img2img", "text2img"]
    block_trigger_inputs = ["mask", "image", None]

    @property
    def description(self):
        return "Auto image blocks for testing"


class TestConditionalPipelineBlocksSelectBlock:
    def test_select_block_with_mask(self):
        blocks = ConditionalImageBlocks()
        assert blocks.select_block(mask="something") == "inpaint"

    def test_select_block_with_image(self):
        blocks = ConditionalImageBlocks()
        assert blocks.select_block(image="something") == "img2img"

    def test_select_block_with_mask_and_image(self):
        blocks = ConditionalImageBlocks()
        assert blocks.select_block(mask="m", image="i") == "inpaint"

    def test_select_block_no_triggers_returns_none(self):
        blocks = ConditionalImageBlocks()
        assert blocks.select_block() is None

    def test_select_block_explicit_none_values(self):
        blocks = ConditionalImageBlocks()
        assert blocks.select_block(mask=None, image=None) is None


class TestConditionalPipelineBlocksWorkflowSelection:
    def test_default_workflow_when_no_triggers(self):
        blocks = ConditionalImageBlocks()
        execution = blocks.get_execution_blocks()
        assert execution is not None
        assert isinstance(execution, TextToImageBlock)

    def test_mask_trigger_selects_inpaint(self):
        blocks = ConditionalImageBlocks()
        execution = blocks.get_execution_blocks(mask=True)
        assert isinstance(execution, InpaintBlock)

    def test_image_trigger_selects_img2img(self):
        blocks = ConditionalImageBlocks()
        execution = blocks.get_execution_blocks(image=True)
        assert isinstance(execution, ImageToImageBlock)

    def test_mask_and_image_selects_inpaint(self):
        blocks = ConditionalImageBlocks()
        execution = blocks.get_execution_blocks(mask=True, image=True)
        assert isinstance(execution, InpaintBlock)

    def test_skippable_block_returns_none(self):
        blocks = OptionalConditionalBlocks()
        execution = blocks.get_execution_blocks()
        assert execution is None

    def test_skippable_block_still_selects_when_triggered(self):
        blocks = OptionalConditionalBlocks()
        execution = blocks.get_execution_blocks(image=True)
        assert isinstance(execution, ImageToImageBlock)


class TestAutoPipelineBlocksSelectBlock:
    def test_auto_select_mask(self):
        blocks = AutoImageBlocks()
        assert blocks.select_block(mask="m") == "inpaint"

    def test_auto_select_image(self):
        blocks = AutoImageBlocks()
        assert blocks.select_block(image="i") == "img2img"

    def test_auto_select_default(self):
        blocks = AutoImageBlocks()
        # No trigger -> returns None -> falls back to default (text2img)
        assert blocks.select_block() is None

    def test_auto_select_priority_order(self):
        blocks = AutoImageBlocks()
        assert blocks.select_block(mask="m", image="i") == "inpaint"


class TestAutoPipelineBlocksWorkflowSelection:
    def test_auto_default_workflow(self):
        blocks = AutoImageBlocks()
        execution = blocks.get_execution_blocks()
        assert isinstance(execution, TextToImageBlock)

    def test_auto_mask_workflow(self):
        blocks = AutoImageBlocks()
        execution = blocks.get_execution_blocks(mask=True)
        assert isinstance(execution, InpaintBlock)

    def test_auto_image_workflow(self):
        blocks = AutoImageBlocks()
        execution = blocks.get_execution_blocks(image=True)
        assert isinstance(execution, ImageToImageBlock)


class TestConditionalPipelineBlocksStructure:
    def test_block_names_accessible(self):
        blocks = ConditionalImageBlocks()
        sub = dict(blocks.sub_blocks)
        assert set(sub.keys()) == {"inpaint", "img2img", "text2img"}

    def test_sub_block_types(self):
        blocks = ConditionalImageBlocks()
        sub = dict(blocks.sub_blocks)
        assert isinstance(sub["inpaint"], InpaintBlock)
        assert isinstance(sub["img2img"], ImageToImageBlock)
        assert isinstance(sub["text2img"], TextToImageBlock)

    def test_description(self):
        blocks = ConditionalImageBlocks()
        assert "Conditional" in blocks.description
@@ -10,11 +10,6 @@ from huggingface_hub import hf_hub_download
import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
-from diffusers.modular_pipelines import (
-    ConditionalPipelineBlocks,
-    LoopSequentialPipelineBlocks,
-    SequentialPipelineBlocks,
-)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    ConfigSpec,

@@ -25,7 +20,6 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
from diffusers.utils import logging

from ..testing_utils import (
-    CaptureLogger,
    backend_empty_cache,
    numpy_cosine_similarity_distance,
    require_accelerator,

@@ -498,117 +492,6 @@ class ModularGuiderTesterMixin:
        assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"


class TestCustomBlockRequirements:
    def get_dummy_block_pipe(self):
        class DummyBlockOne:
            # keep two arbitrary deps so that we can test warnings.
            _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}

        class DummyBlockTwo:
            # keep two dependencies that will be available during testing.
            _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}

        pipe = SequentialPipelineBlocks.from_blocks_dict(
            {"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
        )
        return pipe

    def get_dummy_conditional_block_pipe(self):
        class DummyBlockOne:
            _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}

        class DummyBlockTwo:
            _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}

        class DummyConditionalBlocks(ConditionalPipelineBlocks):
            block_classes = [DummyBlockOne, DummyBlockTwo]
            block_names = ["block_one", "block_two"]
            block_trigger_inputs = []

            def select_block(self, **kwargs):
                return "block_one"

        return DummyConditionalBlocks()

    def get_dummy_loop_block_pipe(self):
        class DummyBlockOne:
            _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}

        class DummyBlockTwo:
            _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}

        return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})

    def test_sequential_block_requirements_save_load(self, tmp_path):
        pipe = self.get_dummy_block_pipe()
        pipe.save_pretrained(str(tmp_path))

        config_path = tmp_path / "modular_config.json"

        with open(config_path, "r") as f:
            config = json.load(f)

        assert "requirements" in config
        requirements = config["requirements"]

        expected_requirements = {
            "xyz": ">=0.8.0",
            "abc": ">=10.0.0",
            "transformers": ">=4.44.0",
            "diffusers": ">=0.2.0",
        }
        assert expected_requirements == requirements

    def test_sequential_block_requirements_warnings(self, tmp_path):
        pipe = self.get_dummy_block_pipe()

        logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
        logger.setLevel(30)

        with CaptureLogger(logger) as cap_logger:
            pipe.save_pretrained(str(tmp_path))

        template = "{req} was specified in the requirements but wasn't found in the current environment"
        msg_xyz = template.format(req="xyz")
        msg_abc = template.format(req="abc")
        assert msg_xyz in str(cap_logger.out)
        assert msg_abc in str(cap_logger.out)

    def test_conditional_block_requirements_save_load(self, tmp_path):
        pipe = self.get_dummy_conditional_block_pipe()
        pipe.save_pretrained(str(tmp_path))

        config_path = tmp_path / "modular_config.json"
        with open(config_path, "r") as f:
            config = json.load(f)

        assert "requirements" in config
        expected_requirements = {
            "xyz": ">=0.8.0",
            "abc": ">=10.0.0",
            "transformers": ">=4.44.0",
            "diffusers": ">=0.2.0",
        }
        assert expected_requirements == config["requirements"]

    def test_loop_block_requirements_save_load(self, tmp_path):
        pipe = self.get_dummy_loop_block_pipe()
        pipe.save_pretrained(str(tmp_path))

        config_path = tmp_path / "modular_config.json"
        with open(config_path, "r") as f:
            config = json.load(f)

        assert "requirements" in config
        expected_requirements = {
            "xyz": ">=0.8.0",
            "abc": ">=10.0.0",
            "transformers": ">=4.44.0",
            "diffusers": ">=0.2.0",
        }
        assert expected_requirements == config["requirements"]


class TestModularModelCardContent:
    def create_mock_block(self, name="TestBlock", description="Test block description"):
        class MockBlock:
@@ -24,14 +24,18 @@ import torch
from diffusers import FluxTransformer2DModel
from diffusers.modular_pipelines import (
    ComponentSpec,
    ConditionalPipelineBlocks,
    InputParam,
    LoopSequentialPipelineBlocks,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
    SequentialPipelineBlocks,
    WanModularPipeline,
)
from diffusers.utils import logging

-from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device
+from ..testing_utils import CaptureLogger, nightly, require_torch, require_torch_accelerator, slow, torch_device


def _create_tiny_model_dir(model_dir):

@@ -463,6 +467,117 @@ class TestModularCustomBlocks:
        assert output_prompt.startswith("Modular diffusers + ")


class TestCustomBlockRequirements:
    def get_dummy_block_pipe(self):
        class DummyBlockOne:
            # keep two arbitrary deps so that we can test warnings.
            _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}

        class DummyBlockTwo:
            # keep two dependencies that will be available during testing.
            _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}

        pipe = SequentialPipelineBlocks.from_blocks_dict(
            {"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
        )
        return pipe

    def get_dummy_conditional_block_pipe(self):
        class DummyBlockOne:
            _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}

        class DummyBlockTwo:
            _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}

        class DummyConditionalBlocks(ConditionalPipelineBlocks):
            block_classes = [DummyBlockOne, DummyBlockTwo]
            block_names = ["block_one", "block_two"]
            block_trigger_inputs = []

            def select_block(self, **kwargs):
                return "block_one"

        return DummyConditionalBlocks()

    def get_dummy_loop_block_pipe(self):
        class DummyBlockOne:
            _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}

        class DummyBlockTwo:
            _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}

        return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})

    def test_sequential_block_requirements_save_load(self, tmp_path):
        pipe = self.get_dummy_block_pipe()
        pipe.save_pretrained(str(tmp_path))

        config_path = tmp_path / "modular_config.json"

        with open(config_path, "r") as f:
            config = json.load(f)

        assert "requirements" in config
        requirements = config["requirements"]

        expected_requirements = {
            "xyz": ">=0.8.0",
            "abc": ">=10.0.0",
            "transformers": ">=4.44.0",
            "diffusers": ">=0.2.0",
        }
        assert expected_requirements == requirements

    def test_sequential_block_requirements_warnings(self, tmp_path):
        pipe = self.get_dummy_block_pipe()

        logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
        logger.setLevel(30)

        with CaptureLogger(logger) as cap_logger:
            pipe.save_pretrained(str(tmp_path))

        template = "{req} was specified in the requirements but wasn't found in the current environment"
        msg_xyz = template.format(req="xyz")
        msg_abc = template.format(req="abc")
        assert msg_xyz in str(cap_logger.out)
        assert msg_abc in str(cap_logger.out)

    def test_conditional_block_requirements_save_load(self, tmp_path):
        pipe = self.get_dummy_conditional_block_pipe()
        pipe.save_pretrained(str(tmp_path))

        config_path = tmp_path / "modular_config.json"
        with open(config_path, "r") as f:
            config = json.load(f)

        assert "requirements" in config
        expected_requirements = {
            "xyz": ">=0.8.0",
            "abc": ">=10.0.0",
            "transformers": ">=4.44.0",
            "diffusers": ">=0.2.0",
        }
        assert expected_requirements == config["requirements"]

    def test_loop_block_requirements_save_load(self, tmp_path):
        pipe = self.get_dummy_loop_block_pipe()
        pipe.save_pretrained(str(tmp_path))

        config_path = tmp_path / "modular_config.json"
        with open(config_path, "r") as f:
            config = json.load(f)

        assert "requirements" in config
        expected_requirements = {
            "xyz": ">=0.8.0",
            "abc": ">=10.0.0",
            "transformers": ">=4.44.0",
            "diffusers": ">=0.2.0",
        }
        assert expected_requirements == config["requirements"]


@slow
@nightly
@require_torch