|
|
|
|
@@ -20,7 +20,6 @@ import json
|
|
|
|
|
import os
|
|
|
|
|
import re
|
|
|
|
|
import shutil
|
|
|
|
|
import signal
|
|
|
|
|
import sys
|
|
|
|
|
import threading
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
@@ -34,6 +33,7 @@ from packaging import version
|
|
|
|
|
|
|
|
|
|
from .. import __version__
|
|
|
|
|
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
|
|
|
|
|
from .constants import DIFFUSERS_DISABLE_REMOTE_CODE
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|
|
|
|
@@ -159,52 +159,25 @@ def check_imports(filename):
|
|
|
|
|
return get_relative_imports(filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _raise_timeout_error(signum, frame):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"Loading this model requires you to execute custom code contained in the model repository on your local "
|
|
|
|
|
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
|
|
|
|
|
if trust_remote_code is None:
|
|
|
|
|
if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
|
|
|
|
|
prev_sig_handler = None
|
|
|
|
|
try:
|
|
|
|
|
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
|
|
|
|
|
signal.alarm(TIME_OUT_REMOTE_CODE)
|
|
|
|
|
while trust_remote_code is None:
|
|
|
|
|
answer = input(
|
|
|
|
|
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
|
|
|
|
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
|
|
|
|
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
|
|
|
|
|
f"Do you wish to run the custom code? [y/N] "
|
|
|
|
|
)
|
|
|
|
|
if answer.lower() in ["yes", "y", "1"]:
|
|
|
|
|
trust_remote_code = True
|
|
|
|
|
elif answer.lower() in ["no", "n", "0", ""]:
|
|
|
|
|
trust_remote_code = False
|
|
|
|
|
signal.alarm(0)
|
|
|
|
|
except Exception:
|
|
|
|
|
# OS which does not support signal.SIGALRM
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
|
|
|
|
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
|
|
|
|
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
|
|
|
|
)
|
|
|
|
|
finally:
|
|
|
|
|
if prev_sig_handler is not None:
|
|
|
|
|
signal.signal(signal.SIGALRM, prev_sig_handler)
|
|
|
|
|
signal.alarm(0)
|
|
|
|
|
elif has_remote_code:
|
|
|
|
|
# For the CI which puts the timeout at 0
|
|
|
|
|
_raise_timeout_error(None, None)
|
|
|
|
|
trust_remote_code = trust_remote_code and not DIFFUSERS_DISABLE_REMOTE_CODE
|
|
|
|
|
if DIFFUSERS_DISABLE_REMOTE_CODE:
|
|
|
|
|
logger.warning(
|
|
|
|
|
"Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable. Ignoring `trust_remote_code`."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if has_remote_code and not trust_remote_code:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Loading {model_name} requires you to execute the configuration file in that"
|
|
|
|
|
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
|
|
|
|
|
" set the option `trust_remote_code=True` to remove this error."
|
|
|
|
|
error_msg = f"The repository for {model_name} contains custom code. "
|
|
|
|
|
error_msg += (
|
|
|
|
|
"Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable."
|
|
|
|
|
if DIFFUSERS_DISABLE_REMOTE_CODE
|
|
|
|
|
else "Pass `trust_remote_code=True` to allow loading remote code modules."
|
|
|
|
|
)
|
|
|
|
|
raise ValueError(error_msg)
|
|
|
|
|
|
|
|
|
|
elif has_remote_code and trust_remote_code:
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"`trust_remote_code` is enabled. Downloading code from {model_name}. Please ensure you trust the contents of this repository"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return trust_remote_code
|
|
|
|
|
|