mirror of https://github.com/huggingface/diffusers.git
synced 2026-02-22 02:39:51 +08:00

Compare commits: modular-sa...update-ker (6 commits)

- 2b7ed4c8dc
- 67f4691cab
- e10fe61303
- 348350cf24
- af35e3806c
- d6bc647932
```diff
@@ -38,6 +38,7 @@ from ..utils import (
     is_flash_attn_available,
     is_flash_attn_version,
     is_kernels_available,
+    is_kernels_version,
     is_sageattention_available,
     is_sageattention_version,
     is_torch_npu_available,
@@ -265,6 +266,7 @@ class _HubKernelConfig:
     repo_id: str
     function_attr: str
     revision: str | None = None
+    version: int | None = None
     kernel_fn: Callable | None = None
     wrapped_forward_attr: str | None = None
     wrapped_backward_attr: str | None = None
```
```diff
@@ -274,27 +276,31 @@

 # Registry for hub-based attention kernels
 _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
-    # TODO: temporary revision for now. Remove when merged upstream into `main`.
     AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
-        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
+        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", version=1
     ),
     AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn3",
         function_attr="flash_attn_varlen_func",
-        # revision="fake-ops-return-probs",
+        version=1,
     ),
     AttentionBackendName.FLASH_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn2",
         function_attr="flash_attn_func",
+        version=1,
         revision=None,
         wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
         wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
     ),
     AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
-        repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
+        repo_id="kernels-community/flash-attn2",
+        function_attr="flash_attn_varlen_func",
+        version=1,
     ),
     AttentionBackendName.SAGE_HUB: _HubKernelConfig(
-        repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
+        repo_id="kernels-community/sage-attention",
+        function_attr="sageattn",
+        version=1,
     ),
 }

```
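The registry now pins each hub kernel to a published `version` instead of a git `revision`, which also lets the temporary `fake-ops-return-probs` revision be dropped. Below is a self-contained sketch of the same pattern; the `loader` callable is an assumption standing in for whatever fetch function the `kernels` package provides, since its exact API is not shown in this diff.

```python
# Standalone sketch of the version-pinned hub-kernel registry pattern above.
# `loader` is a placeholder for the real kernel-fetching function; its
# signature here is an assumption, not the actual diffusers/kernels API.
from dataclasses import dataclass
from typing import Callable


@dataclass
class HubKernelConfig:
    repo_id: str
    function_attr: str
    revision: str | None = None
    version: int | None = None  # new field: pin a published kernel version


KERNEL_REGISTRY: dict[str, HubKernelConfig] = {
    "flash_hub": HubKernelConfig("kernels-community/flash-attn2", "flash_attn_func", version=1),
    "sage_hub": HubKernelConfig("kernels-community/sage-attention", "sageattn", version=1),
}


def resolve_kernel(name: str, loader: Callable):
    config = KERNEL_REGISTRY[name]
    # Pass both pins through; the loader decides precedence (version over revision here).
    module = loader(config.repo_id, revision=config.revision, version=config.version)
    return getattr(module, config.function_attr)
```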
```diff
@@ -464,6 +470,10 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             raise RuntimeError(
                 f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
+        if not is_kernels_version(">=", "0.12"):
+            raise RuntimeError(
+                f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
+            )

     elif backend == AttentionBackendName.AITER:
         if not _CAN_USE_AITER_ATTN:
```
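This extends the existing availability check with a minimum-version gate, matching the new `version=`-based registry entries above. For illustration only, here is the same gate written as a generic standalone helper (this shows the pattern, not diffusers' actual implementation):

```python
# Generic availability-plus-version gate, illustrating the check above.
from importlib.metadata import PackageNotFoundError, version as installed_version

from packaging.version import parse


def require_package(name: str, minimum: str) -> None:
    try:
        current = installed_version(name)
    except PackageNotFoundError:
        raise RuntimeError(f"`{name}` isn't available. Please install it with `pip install {name}`.")
    if parse(current) < parse(minimum):
        raise RuntimeError(f"`{name}` needs to be at least {minimum}. Please update with `pip install -U {name}`.")


require_package("kernels", "0.12")  # mirrors the new check in the diff
```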
```diff
@@ -14,7 +14,6 @@
 import importlib
 import inspect
 import os
-import sys
 import traceback
 import warnings
 from collections import OrderedDict
@@ -29,16 +28,10 @@ from tqdm.auto import tqdm
 from typing_extensions import Self

 from ..configuration_utils import ConfigMixin, FrozenDict
-from ..pipelines.pipeline_loading_utils import (
-    LOADABLE_CLASSES,
-    _fetch_class_library_tuple,
-    _unwrap_model,
-    simple_get_class_obj,
-)
+from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj
 from ..utils import PushToHubMixin, is_accelerate_available, logging
 from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
 from ..utils.hub_utils import load_or_create_model_card, populate_model_card
-from ..utils.torch_utils import is_compiled_module
 from .components_manager import ComponentsManager
 from .modular_pipeline_utils import (
     MODULAR_MODEL_CARD_TEMPLATE,
```
```diff
@@ -1826,111 +1819,29 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
         )
         return pipeline

-    def save_pretrained(
-        self,
-        save_directory: str | os.PathLike,
-        safe_serialization: bool = True,
-        variant: str | None = None,
-        max_shard_size: int | str | None = None,
-        push_to_hub: bool = False,
-        **kwargs,
-    ):
+    def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
         """
-        Save the pipeline and all its components to a directory, so that it can be re-loaded using the
-        [`~ModularPipeline.from_pretrained`] class method.
+        Save the pipeline to a directory. It does not save components, you need to save them separately.

         Args:
             save_directory (`str` or `os.PathLike`):
-                Directory to save the pipeline to. Will be created if it doesn't exist.
-            safe_serialization (`bool`, *optional*, defaults to `True`):
-                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
-            variant (`str`, *optional*):
-                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
-            max_shard_size (`int` or `str`, defaults to `None`):
-                The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
-                lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
-                If expressed as an integer, the unit is bytes.
-            push_to_hub (`bool`, *optional*, defaults to `False`):
-                Whether to push the pipeline to the Hugging Face model hub after saving it.
-            **kwargs: Additional keyword arguments passed along to the push to hub method.
+                Path to the directory where the pipeline will be saved.
+            push_to_hub (`bool`, optional):
+                Whether to push the pipeline to the huggingface hub.
+            **kwargs: Additional arguments passed to `save_config()` method
         """
-        overwrite_modular_index = kwargs.pop("overwrite_modular_index", False)
-        repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
-
-        for component_name, component_spec in self._component_specs.items():
-            sub_model = getattr(self, component_name, None)
-            if sub_model is None:
-                continue
-
-            model_cls = sub_model.__class__
-            if is_compiled_module(sub_model):
-                sub_model = _unwrap_model(sub_model)
-                model_cls = sub_model.__class__
-
-            save_method_name = None
-            for library_name, library_classes in LOADABLE_CLASSES.items():
-                if library_name in sys.modules:
-                    library = importlib.import_module(library_name)
-                else:
-                    logger.info(
-                        f"{library_name} is not installed. Cannot save {component_name} as {library_classes} from {library_name}"
-                    )
-                    continue
-
-                for base_class, save_load_methods in library_classes.items():
-                    class_candidate = getattr(library, base_class, None)
-                    if class_candidate is not None and issubclass(model_cls, class_candidate):
-                        save_method_name = save_load_methods[0]
-                        break
-                if save_method_name is not None:
-                    break
-
-            if save_method_name is None:
-                logger.warning(f"self.{component_name}={sub_model} of type {type(sub_model)} cannot be saved.")
-                continue
-
-            save_method = getattr(sub_model, save_method_name)
-            save_method_signature = inspect.signature(save_method)
-            save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
-            save_method_accept_variant = "variant" in save_method_signature.parameters
-            save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
-
-            save_kwargs = {}
-            if save_method_accept_safe:
-                save_kwargs["safe_serialization"] = safe_serialization
-            if save_method_accept_variant:
-                save_kwargs["variant"] = variant
-            if save_method_accept_max_shard_size and max_shard_size is not None:
-                save_kwargs["max_shard_size"] = max_shard_size
-
-            save_method(os.path.join(save_directory, component_name), **save_kwargs)
-
         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
             private = kwargs.pop("private", None)
             create_pr = kwargs.pop("create_pr", False)
             token = kwargs.pop("token", None)
+            repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
             repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id

-            if overwrite_modular_index:
-                for component_name, component_spec in self._component_specs.items():
-                    if component_spec.default_creation_method != "from_pretrained":
-                        continue
-                    if component_name not in self.config:
-                        continue
-
-                    library, class_name, component_spec_dict = self.config[component_name]
-                    component_spec_dict["pretrained_model_name_or_path"] = repo_id
-                    component_spec_dict["subfolder"] = component_name
-                    if variant is not None and "variant" in component_spec_dict:
-                        component_spec_dict["variant"] = variant
-
-                    self.register_to_config(**{component_name: (library, class_name, component_spec_dict)})
-
-        self.save_config(save_directory=save_directory)
-
-        if push_to_hub:
             # Generate modular pipeline card content
             card_content = generate_modular_model_card_content(self.blocks)

             # Create a new empty model card and eventually tag it
             model_card = load_or_create_model_card(
                 repo_id,
                 token=token,
```
```diff
@@ -1939,8 +1850,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
                 is_modular=True,
             )
             model_card = populate_model_card(model_card, tags=card_content["tags"])

             model_card.save(os.path.join(save_directory, "README.md"))
+
+        # YiYi TODO: maybe order the json file to make it more readable: configs first, then components
+        self.save_config(save_directory=save_directory)
+
+        if push_to_hub:
             self._upload_folder(
                 save_directory,
                 repo_id,
```
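Taken together, these two hunks change the contract of `ModularPipeline.save_pretrained`: it now writes only the pipeline config (plus a model card when pushing), and component weights must be saved by the caller. A hypothetical usage sketch follows; the repo id and the `unet` component attribute are placeholders, not taken from this diff.

```python
# Hypothetical usage of the slimmed-down save_pretrained.
from diffusers.modular_pipelines import ModularPipeline  # import path assumed

pipe = ModularPipeline.from_pretrained("some-org/some-modular-repo")  # placeholder repo
pipe.save_pretrained("./my-modular-pipe")            # writes the config JSON only
pipe.unet.save_pretrained("./my-modular-pipe/unet")  # components are now saved explicitly
```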
```diff
@@ -86,6 +86,7 @@ from .import_utils import (
     is_inflect_available,
     is_invisible_watermark_available,
     is_kernels_available,
+    is_kernels_version,
     is_kornia_available,
     is_librosa_available,
     is_matplotlib_available,
```
```diff
@@ -724,6 +724,22 @@ def is_transformers_version(operation: str, version: str):
     return compare_versions(parse(_transformers_version), operation, version)


+@cache
+def is_kernels_version(operation: str, version: str):
+    """
+    Compares the current Kernels version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _kernels_available:
+        return False
+    return compare_versions(parse(_kernels_version), operation, version)
+
+
 @cache
 def is_hf_hub_version(operation: str, version: str):
     """
```
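Since `is_kernels_version` returns `False` when `kernels` is not installed at all, a single call covers both the availability and the version check. A short usage sketch, with the helper imported from `diffusers.utils` as exported in the `@@ -86,6 +86,7 @@` hunk above:

```python
from diffusers.utils import is_kernels_version  # exported per the hunk above

# False when `kernels` is missing or older than 0.12.
if is_kernels_version(">=", "0.12"):
    ...  # safe to use the version-pinned hub attention backends
else:
    print("Run `pip install -U kernels` to use the hub attention backends.")
```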