Compare commits

...

4 Commits

Author SHA1 Message Date
DN6
a4415e2bd6 update 2025-10-23 08:19:12 +05:30
Dhruv Nair
6098e45b36 Merge branch 'main' into transformer-clip-fix 2025-10-22 16:13:16 +05:30
DN6
027e387bb1 update 2025-10-21 11:15:12 +05:30
DN6
6a0a3c0462 update 2025-10-21 11:09:09 +05:30
3 changed files with 64 additions and 1 deletions

View File

@@ -33,6 +33,7 @@ from ..utils import (
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
_maybe_remap_transformers_class,
deprecate,
get_class_from_dynamic_module,
is_accelerate_available,
@@ -356,6 +357,11 @@ def maybe_raise_or_warn(
"""Simple helper method to raise or warn in case incorrect module has been passed"""
if not is_pipeline_module:
library = importlib.import_module(library_name)
# Handle deprecated Transformers classes
if library_name == "transformers":
class_name = _maybe_remap_transformers_class(class_name) or class_name
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
@@ -390,6 +396,11 @@ def simple_get_class_obj(library_name, class_name):
class_obj = getattr(pipeline_module, class_name)
else:
library = importlib.import_module(library_name)
# Handle deprecated Transformers classes
if library_name == "transformers":
class_name = _maybe_remap_transformers_class(class_name) or class_name
class_obj = getattr(library, class_name)
return class_obj
@@ -416,6 +427,10 @@ def get_class_obj_and_candidates(
# else we just import it from the library.
library = importlib.import_module(library_name)
# Handle deprecated Transformers classes
if library_name == "transformers":
class_name = _maybe_remap_transformers_class(class_name) or class_name
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

View File

@@ -38,7 +38,7 @@ from .constants import (
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
from .deprecation_utils import deprecate
from .deprecation_utils import _maybe_remap_transformers_class, deprecate
from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module
from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video

View File

@@ -4,6 +4,54 @@ from typing import Any, Dict, Optional, Union
from packaging import version
from ..utils import logging
logger = logging.get_logger(__name__)
# Mapping for deprecated Transformers classes to their replacements
# This is used to handle models that reference deprecated class names in their configs
# Reference: https://github.com/huggingface/transformers/issues/40822
# Format: {
# "DeprecatedClassName": {
# "new_class": "NewClassName",
# "transformers_version": (">=", "5.0.0"), # (operation, version) tuple
# }
# }
_TRANSFORMERS_CLASS_REMAPPING = {
"CLIPFeatureExtractor": {
"new_class": "CLIPImageProcessor",
"transformers_version": (">", "4.57.0"),
},
}
def _maybe_remap_transformers_class(class_name: str) -> Optional[str]:
"""
Check if a Transformers class should be remapped to a newer version.
Args:
class_name: The name of the class to check
Returns:
The new class name if remapping should occur, None otherwise
"""
if class_name not in _TRANSFORMERS_CLASS_REMAPPING:
return None
from .import_utils import is_transformers_version
mapping = _TRANSFORMERS_CLASS_REMAPPING[class_name]
operation, required_version = mapping["transformers_version"]
# Only remap if the transformers version meets the requirement
if is_transformers_version(operation, required_version):
new_class = mapping["new_class"]
logger.warning(f"{class_name} appears to have been deprecated in transformers. Using {new_class} instead.")
return mapping["new_class"]
return None
def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
from .. import __version__