import gc
import inspect
import logging
import os
import queue
import threading
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Union

import pandas as pd
import torch
import torch.utils.benchmark as benchmark

from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.testing_utils import require_torch_gpu, torch_device


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)

NUM_WARMUP_ROUNDS = 5

def benchmark_fn(f, *args, **kwargs):
    # Returns the mean latency of `f(*args, **kwargs)` in seconds, rounded to 3 decimal places.
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=1,
    )
    return float(f"{(t0.blocked_autorange().mean):.3f}")

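# Usage sketch for `benchmark_fn` (hypothetical; `torch.mm` on random CUDA tensors stands in
# for an arbitrary callable, and is not part of this module):
#
#     x = torch.randn(1024, 1024, device="cuda")
#     latency_s = benchmark_fn(torch.mm, x, x)
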
def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.reset_peak_memory_stats()

# Adapted from https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py
def calculate_flops(model, input_dict):
    try:
        from torchprofile import profile_macs
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError(
            "`calculate_flops` requires the `torchprofile` package. Install it with `pip install torchprofile`."
        ) from e

    # This is a hacky way to convert the kwargs to args as `profile_macs` cries about kwargs.
    sig = inspect.signature(model.forward)
    param_names = [
        p.name
        for p in sig.parameters.values()
        if p.kind
        in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        and p.name != "self"
    ]
    bound = sig.bind_partial(**input_dict)
    bound.apply_defaults()
    args = tuple(bound.arguments[name] for name in param_names)

    model.eval()
    with torch.no_grad():
        macs = profile_macs(model, args)
    flops = 2 * macs  # 1 MAC operation = 2 FLOPs (1 multiplication + 1 addition)
    return flops

def calculate_params(model):
    return sum(p.numel() for p in model.parameters())

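# Usage sketch (hypothetical; `model` is any loaded diffusers model and `inputs` the kwargs its
# `forward` expects). This mirrors how `run_benchmark` reports these numbers below:
#
#     num_params_B = round(calculate_params(model) / 1e9, 2)
#     flops_G = round(calculate_flops(model, input_dict=inputs) / 1e9, 2)
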
# Default model initialization. Users can supply their own init function in a
# `BenchmarkScenario` if this one doesn't suffice; for most cases it should.
def model_init_fn(model_cls, group_offload_kwargs=None, layerwise_upcasting=False, **init_kwargs):
    model = model_cls.from_pretrained(**init_kwargs).eval()
    if group_offload_kwargs and isinstance(group_offload_kwargs, dict):
        model.enable_group_offload(**group_offload_kwargs)
    else:
        model.to(torch_device)
    if layerwise_upcasting:
        model.enable_layerwise_casting(
            storage_dtype=torch.float8_e4m3fn, compute_dtype=init_kwargs.get("torch_dtype", torch.bfloat16)
        )
    return model

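# Usage sketch for `model_init_fn` (hypothetical checkpoint id; the extra kwargs are forwarded to
# `from_pretrained`, and `group_offload_kwargs` / `layerwise_upcasting` toggle the optional paths above):
#
#     from diffusers import AutoencoderKL
#
#     vae = model_init_fn(
#         AutoencoderKL,
#         pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-v1-5",
#         subfolder="vae",
#         torch_dtype=torch.bfloat16,
#     )
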
@dataclass
class BenchmarkScenario:
    name: str
    model_cls: ModelMixin
    model_init_kwargs: Dict[str, Any]
    model_init_fn: Callable
    get_model_input_dict: Callable
    compile_kwargs: Optional[Dict[str, Any]] = None

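# Example scenario (a minimal sketch; the checkpoint id and input shapes are illustrative
# placeholders for a Stable Diffusion v1.5-style UNet, not part of this module):
#
#     from diffusers import UNet2DConditionModel
#
#     def get_unet_inputs():
#         return {
#             "sample": torch.randn(1, 4, 64, 64, dtype=torch.bfloat16, device=torch_device),
#             "timestep": torch.tensor([1.0], device=torch_device),
#             "encoder_hidden_states": torch.randn(1, 77, 768, dtype=torch.bfloat16, device=torch_device),
#         }
#
#     scenario = BenchmarkScenario(
#         name="sd15-unet-bf16",
#         model_cls=UNet2DConditionModel,
#         model_init_kwargs={
#             "pretrained_model_name_or_path": "stable-diffusion-v1-5/stable-diffusion-v1-5",
#             "subfolder": "unet",
#             "torch_dtype": torch.bfloat16,
#         },
#         model_init_fn=model_init_fn,
#         get_model_input_dict=get_unet_inputs,
#         compile_kwargs={"fullgraph": True, "mode": "default"},
#     )
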
@require_torch_gpu
class BenchmarkMixin:
    def pre_benchmark(self):
        flush()
        torch.compiler.reset()

    def post_benchmark(self, model):
        model.cpu()
        flush()
        torch.compiler.reset()

    @torch.no_grad()
    def run_benchmark(self, scenario: BenchmarkScenario):
        # 0) Basic stats
        logger.info(f"Running scenario: {scenario.name}.")
        try:
            model = scenario.model_init_fn(scenario.model_cls, **scenario.model_init_kwargs)
            num_params = round(calculate_params(model) / 1e9, 2)
            try:
                flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e9, 2)
            except Exception as e:
                logger.info(f"Problem in calculating FLOPs:\n{e}")
                flops = None
            model.cpu()
            del model
        except Exception as e:
            logger.info(f"Error while initializing the model and calculating FLOPs:\n{e}")
            return {}
        self.pre_benchmark()

        # 1) plain stats
        results = {}
        plain = None
        try:
            plain = self._run_phase(
                model_cls=scenario.model_cls,
                init_fn=scenario.model_init_fn,
                init_kwargs=scenario.model_init_kwargs,
                get_input_fn=scenario.get_model_input_dict,
                compile_kwargs=None,
            )
        except Exception as e:
            logger.info(f"Benchmark could not be run with the following error:\n{e}")
            return results

        # 2) compiled stats (if any)
        compiled = {"time": None, "memory": None}
        if scenario.compile_kwargs:
            try:
                compiled = self._run_phase(
                    model_cls=scenario.model_cls,
                    init_fn=scenario.model_init_fn,
                    init_kwargs=scenario.model_init_kwargs,
                    get_input_fn=scenario.get_model_input_dict,
                    compile_kwargs=scenario.compile_kwargs,
                )
            except Exception as e:
                logger.info(f"Compilation benchmark could not be run with the following error:\n{e}")
                if plain is None:
                    return results

        # 3) merge
        result = {
            "scenario": scenario.name,
            "model_cls": scenario.model_cls.__name__,
            "num_params_B": num_params,
            "flops_G": flops,
            "time_plain_s": plain["time"],
            "mem_plain_GB": plain["memory"],
            "time_compile_s": compiled["time"],
            "mem_compile_GB": compiled["memory"],
        }
        if scenario.compile_kwargs:
            result["fullgraph"] = scenario.compile_kwargs.get("fullgraph", False)
            result["mode"] = scenario.compile_kwargs.get("mode", "default")
        else:
            result["fullgraph"], result["mode"] = None, None
        return result

    def run_benchmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[BenchmarkScenario]], filename: str):
        if not isinstance(scenarios, list):
            scenarios = [scenarios]
        record_queue = queue.Queue()
        stop_signal = object()

        def _writer_thread():
            while True:
                item = record_queue.get()
                if item is stop_signal:
                    break
                df_row = pd.DataFrame([item])
                write_header = not os.path.exists(filename)
                df_row.to_csv(filename, mode="a", header=write_header, index=False)
                record_queue.task_done()

            # Account for the stop signal fetched above.
            record_queue.task_done()

        writer = threading.Thread(target=_writer_thread, daemon=True)
        writer.start()

        for s in scenarios:
            try:
                record = self.run_benchmark(s)
                if record:
                    record_queue.put(record)
                else:
                    logger.info(f"Record empty from scenario: {s.name}.")
            except Exception as e:
                logger.info(f"Running scenario ({s.name}) led to error:\n{e}")

        record_queue.put(stop_signal)
        # Wait for the writer to drain the queue before reporting completion.
        writer.join()
        logger.info(f"Results serialized to {filename=}.")

    def _run_phase(
        self,
        *,
        model_cls: ModelMixin,
        init_fn: Callable,
        init_kwargs: Dict[str, Any],
        get_input_fn: Callable,
        compile_kwargs: Optional[Dict[str, Any]],
    ) -> Dict[str, float]:
        # setup
        self.pre_benchmark()

        # init & (optional) compile
        model = init_fn(model_cls, **init_kwargs)
        if compile_kwargs:
            model.compile(**compile_kwargs)

        # build inputs
        inp = get_input_fn()

        # measure
        run_ctx = torch._inductor.utils.fresh_inductor_cache() if compile_kwargs else nullcontext()
        with run_ctx:
            for _ in range(NUM_WARMUP_ROUNDS):
                _ = model(**inp)
            time_s = benchmark_fn(lambda m, d: m(**d), model, inp)
        mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
        mem_gb = round(mem_gb, 2)

        # teardown
        self.post_benchmark(model)
        del model
        return {"time": time_s, "memory": mem_gb}
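
# End-to-end driver sketch (hypothetical; assumes a `scenario` constructed as in the
# `BenchmarkScenario` example above and writes one CSV row per scenario):
#
#     if __name__ == "__main__":
#         BenchmarkMixin().run_benchmarks_and_collate(scenario, filename="benchmark_results.csv")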