Compare commits

...

3 Commits

Author SHA1 Message Date
Charles
6c12a205a0 Merge branch 'main' into version-checks-cache 2025-10-06 15:47:16 +02:00
Charles
39216fc91c lru_cache for Python 3.8 2025-09-26 17:42:01 +02:00
Charles
2ca3cadb35 [perf] Cache version checks
I recently noticed that we are spending a non-negligible amount of time in `version.parse` when running pipelines (approx. ~50ms per step for the QwenImageEdit pipeline on a ZeroGPU Space for instance, which in this case represents almost 10% of the actual compute). The calls to those version checks originate from:
- 4588bbeb42/src/diffusers/hooks/hooks.py (L277)

Maybe that the issue can otherwise be solved from root (why do we need to unwrap the modules at each call?) or maybe that my particular setup triggered this? (I patched the forward method at the blocks level but I don't feel like it has an incidence over _set_context)
2025-09-26 17:28:55 +02:00

View File

@@ -21,6 +21,7 @@ import operator as op
import os import os
import sys import sys
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from functools import lru_cache as cache
from itertools import chain from itertools import chain
from types import ModuleType from types import ModuleType
from typing import Any, Tuple, Union from typing import Any, Tuple, Union
@@ -673,6 +674,7 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338 # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
@cache
def is_torch_version(operation: str, version: str): def is_torch_version(operation: str, version: str):
""" """
Compares the current PyTorch version to a given reference with an operation. Compares the current PyTorch version to a given reference with an operation.
@@ -686,6 +688,7 @@ def is_torch_version(operation: str, version: str):
return compare_versions(parse(_torch_version), operation, version) return compare_versions(parse(_torch_version), operation, version)
@cache
def is_torch_xla_version(operation: str, version: str): def is_torch_xla_version(operation: str, version: str):
""" """
Compares the current torch_xla version to a given reference with an operation. Compares the current torch_xla version to a given reference with an operation.
@@ -701,6 +704,7 @@ def is_torch_xla_version(operation: str, version: str):
return compare_versions(parse(_torch_xla_version), operation, version) return compare_versions(parse(_torch_xla_version), operation, version)
@cache
def is_transformers_version(operation: str, version: str): def is_transformers_version(operation: str, version: str):
""" """
Compares the current Transformers version to a given reference with an operation. Compares the current Transformers version to a given reference with an operation.
@@ -716,6 +720,7 @@ def is_transformers_version(operation: str, version: str):
return compare_versions(parse(_transformers_version), operation, version) return compare_versions(parse(_transformers_version), operation, version)
@cache
def is_hf_hub_version(operation: str, version: str): def is_hf_hub_version(operation: str, version: str):
""" """
Compares the current Hugging Face Hub version to a given reference with an operation. Compares the current Hugging Face Hub version to a given reference with an operation.
@@ -731,6 +736,7 @@ def is_hf_hub_version(operation: str, version: str):
return compare_versions(parse(_hf_hub_version), operation, version) return compare_versions(parse(_hf_hub_version), operation, version)
@cache
def is_accelerate_version(operation: str, version: str): def is_accelerate_version(operation: str, version: str):
""" """
Compares the current Accelerate version to a given reference with an operation. Compares the current Accelerate version to a given reference with an operation.
@@ -746,6 +752,7 @@ def is_accelerate_version(operation: str, version: str):
return compare_versions(parse(_accelerate_version), operation, version) return compare_versions(parse(_accelerate_version), operation, version)
@cache
def is_peft_version(operation: str, version: str): def is_peft_version(operation: str, version: str):
""" """
Compares the current PEFT version to a given reference with an operation. Compares the current PEFT version to a given reference with an operation.
@@ -761,6 +768,7 @@ def is_peft_version(operation: str, version: str):
return compare_versions(parse(_peft_version), operation, version) return compare_versions(parse(_peft_version), operation, version)
@cache
def is_bitsandbytes_version(operation: str, version: str): def is_bitsandbytes_version(operation: str, version: str):
""" """
Args: Args:
@@ -775,6 +783,7 @@ def is_bitsandbytes_version(operation: str, version: str):
return compare_versions(parse(_bitsandbytes_version), operation, version) return compare_versions(parse(_bitsandbytes_version), operation, version)
@cache
def is_gguf_version(operation: str, version: str): def is_gguf_version(operation: str, version: str):
""" """
Compares the current Accelerate version to a given reference with an operation. Compares the current Accelerate version to a given reference with an operation.
@@ -790,6 +799,7 @@ def is_gguf_version(operation: str, version: str):
return compare_versions(parse(_gguf_version), operation, version) return compare_versions(parse(_gguf_version), operation, version)
@cache
def is_torchao_version(operation: str, version: str): def is_torchao_version(operation: str, version: str):
""" """
Compares the current torchao version to a given reference with an operation. Compares the current torchao version to a given reference with an operation.
@@ -805,6 +815,7 @@ def is_torchao_version(operation: str, version: str):
return compare_versions(parse(_torchao_version), operation, version) return compare_versions(parse(_torchao_version), operation, version)
@cache
def is_k_diffusion_version(operation: str, version: str): def is_k_diffusion_version(operation: str, version: str):
""" """
Compares the current k-diffusion version to a given reference with an operation. Compares the current k-diffusion version to a given reference with an operation.
@@ -820,6 +831,7 @@ def is_k_diffusion_version(operation: str, version: str):
return compare_versions(parse(_k_diffusion_version), operation, version) return compare_versions(parse(_k_diffusion_version), operation, version)
@cache
def is_optimum_quanto_version(operation: str, version: str): def is_optimum_quanto_version(operation: str, version: str):
""" """
Compares the current Accelerate version to a given reference with an operation. Compares the current Accelerate version to a given reference with an operation.
@@ -835,6 +847,7 @@ def is_optimum_quanto_version(operation: str, version: str):
return compare_versions(parse(_optimum_quanto_version), operation, version) return compare_versions(parse(_optimum_quanto_version), operation, version)
@cache
def is_nvidia_modelopt_version(operation: str, version: str): def is_nvidia_modelopt_version(operation: str, version: str):
""" """
Compares the current Nvidia ModelOpt version to a given reference with an operation. Compares the current Nvidia ModelOpt version to a given reference with an operation.
@@ -850,6 +863,7 @@ def is_nvidia_modelopt_version(operation: str, version: str):
return compare_versions(parse(_nvidia_modelopt_version), operation, version) return compare_versions(parse(_nvidia_modelopt_version), operation, version)
@cache
def is_xformers_version(operation: str, version: str): def is_xformers_version(operation: str, version: str):
""" """
Compares the current xformers version to a given reference with an operation. Compares the current xformers version to a given reference with an operation.
@@ -865,6 +879,7 @@ def is_xformers_version(operation: str, version: str):
return compare_versions(parse(_xformers_version), operation, version) return compare_versions(parse(_xformers_version), operation, version)
@cache
def is_sageattention_version(operation: str, version: str): def is_sageattention_version(operation: str, version: str):
""" """
Compares the current sageattention version to a given reference with an operation. Compares the current sageattention version to a given reference with an operation.
@@ -880,6 +895,7 @@ def is_sageattention_version(operation: str, version: str):
return compare_versions(parse(_sageattention_version), operation, version) return compare_versions(parse(_sageattention_version), operation, version)
@cache
def is_flash_attn_version(operation: str, version: str): def is_flash_attn_version(operation: str, version: str):
""" """
Compares the current flash-attention version to a given reference with an operation. Compares the current flash-attention version to a given reference with an operation.