mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-26 09:51:34 +08:00
* add a profiling worflow. * fix * fix * more clarification * add points. * up * cache hooks * improve readme. * propagate deletion. * up * up * wan fixes. * more * up * add more traces. * up * better title * cuda graphs. * up * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * add torch.compile link. * approach -> How the tooling works * table * unavoidable gaps. * make important * note on regional compilation * Apply suggestions from code review Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * make regional compilation note clearer. * Apply suggestions from code review Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * clarify scheduler related changes. * Apply suggestions from code review Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * Update examples/profiling/README.md Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> * up * formatting * benchmarking runtime * up * up * up * up * Update examples/profiling/README.md Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com> --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
197 lines
7.8 KiB
Python
197 lines
7.8 KiB
Python
"""
|
|
Profile diffusers pipelines with torch.profiler.
|
|
|
|
Usage:
|
|
python profiling/profiling_pipelines.py --pipeline flux --mode eager
|
|
python profiling/profiling_pipelines.py --pipeline flux --mode compile
|
|
python profiling/profiling_pipelines.py --pipeline flux --mode both
|
|
python profiling/profiling_pipelines.py --pipeline all --mode eager
|
|
python profiling/profiling_pipelines.py --pipeline wan --mode eager --full_decode
|
|
python profiling/profiling_pipelines.py --pipeline flux --mode compile --num_steps 4
|
|
|
|
Benchmarking (wall-clock time, no profiler overhead):
|
|
python profiling/profiling_pipelines.py --pipeline flux --mode compile --benchmark
|
|
python profiling/profiling_pipelines.py --pipeline flux --mode both --benchmark --num_runs 10 --num_warmups 3
|
|
"""
|
|
|
|
import argparse
|
|
import copy
|
|
import logging
|
|
|
|
import torch
|
|
from profiling_utils import PipelineProfiler, PipelineProfilingConfig
|
|
|
|
|
|
# Configure root logging once for the CLI: INFO level, timestamped records.
# (logging.basicConfig does nothing if the root logger already has handlers.)
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)

# Text prompt shared by every pipeline entry in the registry.
PROMPT = "A cat holding a sign that says hello world"
|
|
|
|
|
|
def build_registry():
    """Return an ordered mapping of CLI pipeline name -> ``PipelineProfilingConfig``.

    The diffusers pipeline classes are imported here rather than at module level,
    so importing this script does not pull them in. Every entry loads its checkpoint
    in bfloat16, uses the shared ``PROMPT`` and requests ``output_type="latent"``.
    By default, this keeps the VAE decode out of the profiled call; ``--full_decode``
    switches it back on. Insertion order is significant: ``--pipeline all`` iterates
    the entries in this order.
    """
    from diffusers import Flux2KleinPipeline, FluxPipeline, LTX2Pipeline, QwenImagePipeline, WanPipeline

    def make_config(name, pipeline_cls, repo_id, **call_kwargs):
        # Only the checkpoint and per-model call arguments vary between entries;
        # fresh dicts are built on each call so no two configs share state.
        return PipelineProfilingConfig(
            name=name,
            pipeline_cls=pipeline_cls,
            pipeline_init_kwargs={
                "pretrained_model_name_or_path": repo_id,
                "torch_dtype": torch.bfloat16,
            },
            pipeline_call_kwargs={"prompt": PROMPT, **call_kwargs, "output_type": "latent"},
        )

    wan_negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

    registry = {}
    registry["flux"] = make_config(
        "flux",
        FluxPipeline,
        "black-forest-labs/FLUX.1-dev",
        height=1024,
        width=1024,
        num_inference_steps=4,
        guidance_scale=3.5,
    )
    registry["flux2"] = make_config(
        "flux2",
        Flux2KleinPipeline,
        "black-forest-labs/FLUX.2-klein-base-9B",
        height=1024,
        width=1024,
        num_inference_steps=4,
        guidance_scale=3.5,
    )
    registry["wan"] = make_config(
        "wan",
        WanPipeline,
        "Wan-AI/Wan2.1-T2V-14B-Diffusers",
        negative_prompt=wan_negative_prompt,
        height=480,
        width=832,
        num_frames=81,
        num_inference_steps=4,
    )
    registry["ltx2"] = make_config(
        "ltx2",
        LTX2Pipeline,
        "Lightricks/LTX-2",
        negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
        height=512,
        width=768,
        num_frames=121,
        num_inference_steps=4,
        guidance_scale=4.0,
    )
    registry["qwenimage"] = make_config(
        "qwenimage",
        QwenImagePipeline,
        "Qwen/Qwen-Image",
        negative_prompt=" ",
        height=1024,
        width=1024,
        num_inference_steps=4,
        true_cfg_scale=4.0,
    )
    return registry
|
|
|
|
|
|
def _build_arg_parser():
    """Build the command-line parser for this script (parsing is left to the caller)."""
    parser = argparse.ArgumentParser(description="Profile diffusers pipelines with torch.profiler")
    parser.add_argument(
        "--pipeline",
        choices=["flux", "flux2", "wan", "ltx2", "qwenimage", "all"],
        required=True,
        help="Which pipeline to profile",
    )
    parser.add_argument(
        "--mode",
        choices=["eager", "compile", "both"],
        default="eager",
        help="Run in eager mode, compile mode, or both",
    )
    parser.add_argument("--output_dir", default="profiling_results", help="Directory for trace output")
    parser.add_argument("--num_steps", type=int, default=None, help="Override num_inference_steps")
    parser.add_argument("--full_decode", action="store_true", help="Profile including VAE decode (output_type='pil')")
    # The three --compile_* flags below only take effect for runs in "compile" mode.
    parser.add_argument(
        "--compile_mode",
        default="default",
        choices=["default", "reduce-overhead", "max-autotune"],
        help="torch.compile mode",
    )
    parser.add_argument("--compile_fullgraph", action="store_true", help="Use fullgraph=True for torch.compile")
    parser.add_argument(
        "--compile_regional",
        action="store_true",
        help="Use compile_repeated_blocks() instead of full model compile",
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="Benchmark wall-clock time instead of profiling. Uses CUDA events, no profiler overhead.",
    )
    parser.add_argument("--num_runs", type=int, default=5, help="Number of timed runs for benchmarking")
    parser.add_argument("--num_warmups", type=int, default=2, help="Number of warmup runs for benchmarking")
    return parser


def _apply_overrides(config, args, mode):
    """Apply the CLI overrides in ``args`` to ``config`` in place for a run in ``mode``."""
    if args.num_steps is not None:
        config.pipeline_call_kwargs["num_inference_steps"] = args.num_steps
    if args.full_decode:
        config.pipeline_call_kwargs["output_type"] = "pil"
    if mode == "compile":
        config.compile_kwargs = {
            "fullgraph": args.compile_fullgraph,
            "mode": args.compile_mode,
        }
        config.compile_regional = args.compile_regional


def main():
    """Profile or benchmark the selected pipeline(s) in the selected mode(s).

    Each (pipeline, mode) combination runs independently. A failure in one,
    including while constructing its ``PipelineProfiler``, is logged with its
    full traceback, and the loop continues with the next combination. If any
    combination failed, the process exits with status 1 after all have run.
    Invalid numeric arguments are rejected up front through ``parser.error``,
    which exits with status 2.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Fail fast on values that cannot give a meaningful run, before any model is loaded.
    if args.num_steps is not None and args.num_steps < 1:
        parser.error("--num_steps must be >= 1")
    if args.num_runs < 1:
        parser.error("--num_runs must be >= 1")
    if args.num_warmups < 0:
        parser.error("--num_warmups must be >= 0")

    registry = build_registry()

    pipeline_names = list(registry) if args.pipeline == "all" else [args.pipeline]
    modes = ["eager", "compile"] if args.mode == "both" else [args.mode]
    action = "benchmark" if args.benchmark else "profile"

    failures = []
    for pipeline_name in pipeline_names:
        for mode in modes:
            # Deep-copy so per-run overrides never leak into the shared registry entry.
            config = copy.deepcopy(registry[pipeline_name])
            _apply_overrides(config, args, mode)

            try:
                profiler = PipelineProfiler(config, args.output_dir)
                if args.benchmark:
                    logger.info(f"Benchmarking {pipeline_name} in {mode} mode...")
                    profiler.benchmark(num_runs=args.num_runs, num_warmups=args.num_warmups)
                else:
                    logger.info(f"Profiling {pipeline_name} in {mode} mode...")
                    trace_file = profiler.run()
                    logger.info(f"Done: {trace_file}")
            except Exception:
                # Best-effort: record the failure with its traceback and move on.
                logger.exception("Failed to %s %s (%s)", action, pipeline_name, mode)
                failures.append(f"{pipeline_name} ({mode})")

    if failures:
        logger.error("%d run(s) failed: %s", len(failures), ", ".join(failures))
        raise SystemExit(1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: run the CLI only when executed as a script, not when imported.
    main()
|