Files
diffusers/examples/profiling/profiling_utils.py
Sayak Paul b114620d85 Add examples on how to profile a pipeline (#13356)
* add a profiling worflow.

* fix

* fix

* more clarification

* add points.

* up

* cache hooks

* improve readme.

* propagate deletion.

* up

* up

* wan fixes.

* more

* up

* add more traces.

* up

* better title

* cuda graphs.

* up

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* add torch.compile link.

* approach -> How the tooling works

* table

* unavoidable gaps.

* make important

* note on regional compilation

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* make regional compilation note clearer.

* Apply suggestions from code review

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* clarify scheduler related changes.

* Apply suggestions from code review

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update examples/profiling/README.md

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* up

* formatting

* benchmarking runtime

* up

* up

* up

* up

* Update examples/profiling/README.md

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-04-03 16:13:01 +02:00

216 lines
6.8 KiB
Python

import functools
import gc
import logging
import os
from dataclasses import dataclass, field
from typing import Any
import torch
import torch.profiler
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)
def annotate(func, name):
"""Wrap a function with torch.profiler.record_function for trace annotation."""
@functools.wraps(func)
def wrapper(*args, **kwargs):
with torch.profiler.record_function(name):
return func(*args, **kwargs)
return wrapper
def annotate_pipeline(pipe):
"""Apply profiler annotations to key pipeline methods.
Monkey-patches bound methods so they appear as named spans in the trace.
Non-invasive — no source modifications required.
"""
annotations = [
("transformer", "forward", "transformer_forward"),
("vae", "decode", "vae_decode"),
("vae", "encode", "vae_encode"),
("scheduler", "step", "scheduler_step"),
]
# Annotate sub-component methods
for component_name, method_name, label in annotations:
component = getattr(pipe, component_name, None)
if component is None:
continue
method = getattr(component, method_name, None)
if method is None:
continue
setattr(component, method_name, annotate(method, label))
# Annotate pipeline-level methods
if hasattr(pipe, "encode_prompt"):
pipe.encode_prompt = annotate(pipe.encode_prompt, "encode_prompt")
def flush():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
def benchmark_fn(f, *args, num_runs=5, num_warmups=2, **kwargs):
"""Benchmark a function using CUDA events for accurate GPU timing.
Uses CUDA events to measure wall-clock time including GPU execution,
without the overhead of torch.profiler. Reports mean and standard deviation
over multiple runs.
Returns:
dict with keys: mean_ms, std_ms, runs_ms (list of individual timings)
"""
# Warmup
for _ in range(num_warmups):
f(*args, **kwargs)
torch.cuda.synchronize()
# Timed runs
times = []
for _ in range(num_runs):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
f(*args, **kwargs)
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end))
mean_ms = sum(times) / len(times)
variance = sum((t - mean_ms) ** 2 for t in times) / len(times)
std_ms = variance**0.5
return {"mean_ms": mean_ms, "std_ms": std_ms, "runs_ms": times}
@dataclass
class PipelineProfilingConfig:
name: str
pipeline_cls: Any
pipeline_init_kwargs: dict[str, Any]
pipeline_call_kwargs: dict[str, Any]
compile_kwargs: dict[str, Any] | None = field(default=None)
compile_regional: bool = False
class PipelineProfiler:
def __init__(self, config: PipelineProfilingConfig, output_dir: str = "profiling_results"):
self.config = config
self.output_dir = output_dir
os.makedirs(output_dir, exist_ok=True)
def setup_pipeline(self, annotate=True):
"""Load the pipeline from pretrained, optionally compile, and annotate."""
logger.info(f"Loading pipeline: {self.config.name}")
pipe = self.config.pipeline_cls.from_pretrained(**self.config.pipeline_init_kwargs)
pipe.to("cuda")
if self.config.compile_kwargs:
if self.config.compile_regional:
logger.info(
f"Regional compilation (compile_repeated_blocks) with kwargs: {self.config.compile_kwargs}"
)
pipe.transformer.compile_repeated_blocks(**self.config.compile_kwargs)
else:
logger.info(f"Full compilation with kwargs: {self.config.compile_kwargs}")
pipe.transformer.compile(**self.config.compile_kwargs)
# Disable tqdm progress bar to avoid CPU overhead / IO between steps
pipe.set_progress_bar_config(disable=True)
if annotate:
annotate_pipeline(pipe)
return pipe
def run(self):
"""Execute the profiling run: warmup, then profile one pipeline call."""
pipe = self.setup_pipeline()
flush()
mode = "compile" if self.config.compile_kwargs else "eager"
trace_file = os.path.join(self.output_dir, f"{self.config.name}_{mode}.json")
# Warmup (pipeline __call__ is already decorated with @torch.no_grad())
logger.info("Running warmup...")
pipe(**self.config.pipeline_call_kwargs)
flush()
# Profile
logger.info("Running profiled iteration...")
activities = [
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
]
with torch.profiler.profile(
activities=activities,
record_shapes=True,
profile_memory=True,
with_stack=True,
) as prof:
with torch.profiler.record_function("pipeline_call"):
pipe(**self.config.pipeline_call_kwargs)
# Export trace
prof.export_chrome_trace(trace_file)
logger.info(f"Chrome trace saved to: {trace_file}")
# Print summary
print("\n" + "=" * 80)
print(f"Profile summary: {self.config.name} ({mode})")
print("=" * 80)
print(
prof.key_averages().table(
sort_by="cuda_time_total",
row_limit=20,
)
)
# Cleanup
pipe.to("cpu")
del pipe
flush()
return trace_file
def benchmark(self, num_runs=5, num_warmups=2):
"""Benchmark pipeline wall-clock time without profiler overhead.
Uses CUDA events for accurate GPU-inclusive timing over multiple runs.
No annotations are applied to avoid any overhead from record_function wrappers.
Reports mean, std, and individual run times.
"""
pipe = self.setup_pipeline(annotate=False)
flush()
mode = "compile" if self.config.compile_kwargs else "eager"
logger.info(f"Benchmarking {self.config.name} ({mode}): {num_warmups} warmup + {num_runs} timed runs...")
result = benchmark_fn(pipe, num_runs=num_runs, num_warmups=num_warmups, **self.config.pipeline_call_kwargs)
print("\n" + "=" * 80)
print(f"Benchmark: {self.config.name} ({mode})")
print("=" * 80)
print(f" Runs: {num_runs} (after {num_warmups} warmup)")
print(f" Mean: {result['mean_ms']:.1f} ms")
print(f" Std: {result['std_ms']:.1f} ms")
print(f" Individual: {', '.join(f'{t:.1f}' for t in result['runs_ms'])} ms")
print("=" * 80)
# Cleanup
pipe.to("cpu")
del pipe
flush()
return result