Compare commits

...

3 Commits

Author SHA1 Message Date
Charles
6c12a205a0 Merge branch 'main' into version-checks-cache 2025-10-06 15:47:16 +02:00
Charles
39216fc91c lru_cache for Python 3.8 2025-09-26 17:42:01 +02:00
Charles
2ca3cadb35 [perf] Cache version checks
I recently noticed that we are spending a non-negligible amount of time in `version.parse` when running pipelines (approx. ~50ms per step for the QwenImageEdit pipeline on a ZeroGPU Space for instance, which in this case represents almost 10% of the actual compute). The calls to those version checks originate from:
- 4588bbeb42/src/diffusers/hooks/hooks.py (L277)

Maybe that the issue can otherwise be solved from root (why do we need to unwrap the modules at each call?) or maybe that my particular setup triggered this? (I patched the forward method at the blocks level but I don't feel like it has an incidence over _set_context)
2025-09-26 17:28:55 +02:00

View File

@@ -21,6 +21,7 @@ import operator as op
import os
import sys
from collections import OrderedDict, defaultdict
from functools import lru_cache as cache
from itertools import chain
from types import ModuleType
from typing import Any, Tuple, Union
@@ -673,6 +674,7 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
@cache
def is_torch_version(operation: str, version: str):
"""
Compares the current PyTorch version to a given reference with an operation.
@@ -686,6 +688,7 @@ def is_torch_version(operation: str, version: str):
return compare_versions(parse(_torch_version), operation, version)
@cache
def is_torch_xla_version(operation: str, version: str):
"""
Compares the current torch_xla version to a given reference with an operation.
@@ -701,6 +704,7 @@ def is_torch_xla_version(operation: str, version: str):
return compare_versions(parse(_torch_xla_version), operation, version)
@cache
def is_transformers_version(operation: str, version: str):
"""
Compares the current Transformers version to a given reference with an operation.
@@ -716,6 +720,7 @@ def is_transformers_version(operation: str, version: str):
return compare_versions(parse(_transformers_version), operation, version)
@cache
def is_hf_hub_version(operation: str, version: str):
"""
Compares the current Hugging Face Hub version to a given reference with an operation.
@@ -731,6 +736,7 @@ def is_hf_hub_version(operation: str, version: str):
return compare_versions(parse(_hf_hub_version), operation, version)
@cache
def is_accelerate_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -746,6 +752,7 @@ def is_accelerate_version(operation: str, version: str):
return compare_versions(parse(_accelerate_version), operation, version)
@cache
def is_peft_version(operation: str, version: str):
"""
Compares the current PEFT version to a given reference with an operation.
@@ -761,6 +768,7 @@ def is_peft_version(operation: str, version: str):
return compare_versions(parse(_peft_version), operation, version)
@cache
def is_bitsandbytes_version(operation: str, version: str):
"""
Args:
@@ -775,6 +783,7 @@ def is_bitsandbytes_version(operation: str, version: str):
return compare_versions(parse(_bitsandbytes_version), operation, version)
@cache
def is_gguf_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -790,6 +799,7 @@ def is_gguf_version(operation: str, version: str):
return compare_versions(parse(_gguf_version), operation, version)
@cache
def is_torchao_version(operation: str, version: str):
"""
Compares the current torchao version to a given reference with an operation.
@@ -805,6 +815,7 @@ def is_torchao_version(operation: str, version: str):
return compare_versions(parse(_torchao_version), operation, version)
@cache
def is_k_diffusion_version(operation: str, version: str):
"""
Compares the current k-diffusion version to a given reference with an operation.
@@ -820,6 +831,7 @@ def is_k_diffusion_version(operation: str, version: str):
return compare_versions(parse(_k_diffusion_version), operation, version)
@cache
def is_optimum_quanto_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -835,6 +847,7 @@ def is_optimum_quanto_version(operation: str, version: str):
return compare_versions(parse(_optimum_quanto_version), operation, version)
@cache
def is_nvidia_modelopt_version(operation: str, version: str):
"""
Compares the current Nvidia ModelOpt version to a given reference with an operation.
@@ -850,6 +863,7 @@ def is_nvidia_modelopt_version(operation: str, version: str):
return compare_versions(parse(_nvidia_modelopt_version), operation, version)
@cache
def is_xformers_version(operation: str, version: str):
"""
Compares the current xformers version to a given reference with an operation.
@@ -865,6 +879,7 @@ def is_xformers_version(operation: str, version: str):
return compare_versions(parse(_xformers_version), operation, version)
@cache
def is_sageattention_version(operation: str, version: str):
"""
Compares the current sageattention version to a given reference with an operation.
@@ -880,6 +895,7 @@ def is_sageattention_version(operation: str, version: str):
return compare_versions(parse(_sageattention_version), operation, version)
@cache
def is_flash_attn_version(operation: str, version: str):
"""
Compares the current flash-attention version to a given reference with an operation.