Compare commits

...

4 Commits

Author SHA1 Message Date
DN6
50c7ddeaea update 2025-08-18 14:02:05 +05:30
DN6
4b2b2b221b update 2025-08-18 13:29:07 +05:30
DN6
0db2ea2bc8 update 2025-08-18 13:20:59 +05:30
DN6
1b26e309f4 update 2025-08-18 11:40:02 +05:30
3 changed files with 19 additions and 45 deletions

View File

@@ -290,7 +290,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
def from_pretrained( def from_pretrained(
cls, cls,
pretrained_model_name_or_path: str, pretrained_model_name_or_path: str,
trust_remote_code: Optional[bool] = None, trust_remote_code: bool = False,
**kwargs, **kwargs,
): ):
hub_kwargs_names = [ hub_kwargs_names = [

View File

@@ -45,6 +45,7 @@ DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8 DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES
# Below should be `True` if the current version of `peft` and `transformers` are compatible with # Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are

View File

@@ -20,7 +20,6 @@ import json
import os import os
import re import re
import shutil import shutil
import signal
import sys import sys
import threading import threading
from pathlib import Path from pathlib import Path
@@ -34,6 +33,7 @@ from packaging import version
from .. import __version__ from .. import __version__
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
from .constants import DIFFUSERS_DISABLE_REMOTE_CODE
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -159,52 +159,25 @@ def check_imports(filename):
return get_relative_imports(filename) return get_relative_imports(filename)
def _raise_timeout_error(signum, frame):
raise ValueError(
"Loading this model requires you to execute custom code contained in the model repository on your local "
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
)
def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code): def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
if trust_remote_code is None: trust_remote_code = trust_remote_code and not DIFFUSERS_DISABLE_REMOTE_CODE
if has_remote_code and TIME_OUT_REMOTE_CODE > 0: if DIFFUSERS_DISABLE_REMOTE_CODE:
prev_sig_handler = None logger.warning(
try: "Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable. Ignoring `trust_remote_code`."
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error) )
signal.alarm(TIME_OUT_REMOTE_CODE)
while trust_remote_code is None:
answer = input(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
f"Do you wish to run the custom code? [y/N] "
)
if answer.lower() in ["yes", "y", "1"]:
trust_remote_code = True
elif answer.lower() in ["no", "n", "0", ""]:
trust_remote_code = False
signal.alarm(0)
except Exception:
# OS which does not support signal.SIGALRM
raise ValueError(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
finally:
if prev_sig_handler is not None:
signal.signal(signal.SIGALRM, prev_sig_handler)
signal.alarm(0)
elif has_remote_code:
# For the CI which puts the timeout at 0
_raise_timeout_error(None, None)
if has_remote_code and not trust_remote_code: if has_remote_code and not trust_remote_code:
raise ValueError( error_msg = f"The repository for {model_name} contains custom code. "
f"Loading {model_name} requires you to execute the configuration file in that" error_msg += (
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then" "Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable."
" set the option `trust_remote_code=True` to remove this error." if DIFFUSERS_DISABLE_REMOTE_CODE
else "Pass `trust_remote_code=True` to allow loading remote code modules."
)
raise ValueError(error_msg)
elif has_remote_code and trust_remote_code:
logger.warning(
f"`trust_remote_code` is enabled. Downloading code from {model_name}. Please ensure you trust the contents of this repository"
) )
return trust_remote_code return trust_remote_code