mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-08 13:34:27 +08:00
Compare commits
4 Commits
pr-tests-f
...
transforme
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4415e2bd6 | ||
|
|
6098e45b36 | ||
|
|
027e387bb1 | ||
|
|
6a0a3c0462 |
@@ -33,6 +33,7 @@ from ..utils import (
|
||||
ONNX_WEIGHTS_NAME,
|
||||
SAFETENSORS_WEIGHTS_NAME,
|
||||
WEIGHTS_NAME,
|
||||
_maybe_remap_transformers_class,
|
||||
deprecate,
|
||||
get_class_from_dynamic_module,
|
||||
is_accelerate_available,
|
||||
@@ -356,6 +357,11 @@ def maybe_raise_or_warn(
|
||||
"""Simple helper method to raise or warn in case incorrect module has been passed"""
|
||||
if not is_pipeline_module:
|
||||
library = importlib.import_module(library_name)
|
||||
|
||||
# Handle deprecated Transformers classes
|
||||
if library_name == "transformers":
|
||||
class_name = _maybe_remap_transformers_class(class_name) or class_name
|
||||
|
||||
class_obj = getattr(library, class_name)
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
@@ -390,6 +396,11 @@ def simple_get_class_obj(library_name, class_name):
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
else:
|
||||
library = importlib.import_module(library_name)
|
||||
|
||||
# Handle deprecated Transformers classes
|
||||
if library_name == "transformers":
|
||||
class_name = _maybe_remap_transformers_class(class_name) or class_name
|
||||
|
||||
class_obj = getattr(library, class_name)
|
||||
|
||||
return class_obj
|
||||
@@ -416,6 +427,10 @@ def get_class_obj_and_candidates(
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
|
||||
# Handle deprecated Transformers classes
|
||||
if library_name == "transformers":
|
||||
class_name = _maybe_remap_transformers_class(class_name) or class_name
|
||||
|
||||
class_obj = getattr(library, class_name)
|
||||
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ from .constants import (
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
)
|
||||
from .deprecation_utils import deprecate
|
||||
from .deprecation_utils import _maybe_remap_transformers_class, deprecate
|
||||
from .doc_utils import replace_example_docstring
|
||||
from .dynamic_modules_utils import get_class_from_dynamic_module
|
||||
from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video
|
||||
|
||||
@@ -4,6 +4,54 @@ from typing import Any, Dict, Optional, Union
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Mapping for deprecated Transformers classes to their replacements
|
||||
# This is used to handle models that reference deprecated class names in their configs
|
||||
# Reference: https://github.com/huggingface/transformers/issues/40822
|
||||
# Format: {
|
||||
# "DeprecatedClassName": {
|
||||
# "new_class": "NewClassName",
|
||||
# "transformers_version": (">=", "5.0.0"), # (operation, version) tuple
|
||||
# }
|
||||
# }
|
||||
_TRANSFORMERS_CLASS_REMAPPING = {
|
||||
"CLIPFeatureExtractor": {
|
||||
"new_class": "CLIPImageProcessor",
|
||||
"transformers_version": (">", "4.57.0"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _maybe_remap_transformers_class(class_name: str) -> Optional[str]:
|
||||
"""
|
||||
Check if a Transformers class should be remapped to a newer version.
|
||||
|
||||
Args:
|
||||
class_name: The name of the class to check
|
||||
|
||||
Returns:
|
||||
The new class name if remapping should occur, None otherwise
|
||||
"""
|
||||
if class_name not in _TRANSFORMERS_CLASS_REMAPPING:
|
||||
return None
|
||||
|
||||
from .import_utils import is_transformers_version
|
||||
|
||||
mapping = _TRANSFORMERS_CLASS_REMAPPING[class_name]
|
||||
operation, required_version = mapping["transformers_version"]
|
||||
|
||||
# Only remap if the transformers version meets the requirement
|
||||
if is_transformers_version(operation, required_version):
|
||||
new_class = mapping["new_class"]
|
||||
logger.warning(f"{class_name} appears to have been deprecated in transformers. Using {new_class} instead.")
|
||||
return mapping["new_class"]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
|
||||
from .. import __version__
|
||||
|
||||
Reference in New Issue
Block a user