mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-19 19:04:49 +08:00
Compare commits
4 Commits
enable-cp-
...
custom-cod
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50c7ddeaea | ||
|
|
4b2b2b221b | ||
|
|
0db2ea2bc8 | ||
|
|
1b26e309f4 |
@@ -290,7 +290,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
|||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
cls,
|
cls,
|
||||||
pretrained_model_name_or_path: str,
|
pretrained_model_name_or_path: str,
|
||||||
trust_remote_code: Optional[bool] = None,
|
trust_remote_code: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
hub_kwargs_names = [
|
hub_kwargs_names = [
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
|
|||||||
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
|
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
|
||||||
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
|
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
|
||||||
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
|
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
|
||||||
|
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES
|
||||||
|
|
||||||
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
||||||
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
|
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import json
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import signal
|
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -34,6 +33,7 @@ from packaging import version
|
|||||||
|
|
||||||
from .. import __version__
|
from .. import __version__
|
||||||
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
|
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
|
||||||
|
from .constants import DIFFUSERS_DISABLE_REMOTE_CODE
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
@@ -159,52 +159,25 @@ def check_imports(filename):
|
|||||||
return get_relative_imports(filename)
|
return get_relative_imports(filename)
|
||||||
|
|
||||||
|
|
||||||
def _raise_timeout_error(signum, frame):
|
|
||||||
raise ValueError(
|
|
||||||
"Loading this model requires you to execute custom code contained in the model repository on your local "
|
|
||||||
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
|
def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
|
||||||
if trust_remote_code is None:
|
trust_remote_code = trust_remote_code and not DIFFUSERS_DISABLE_REMOTE_CODE
|
||||||
if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
|
if DIFFUSERS_DISABLE_REMOTE_CODE:
|
||||||
prev_sig_handler = None
|
logger.warning(
|
||||||
try:
|
"Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable. Ignoring `trust_remote_code`."
|
||||||
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
|
)
|
||||||
signal.alarm(TIME_OUT_REMOTE_CODE)
|
|
||||||
while trust_remote_code is None:
|
|
||||||
answer = input(
|
|
||||||
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
|
||||||
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
|
||||||
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
|
|
||||||
f"Do you wish to run the custom code? [y/N] "
|
|
||||||
)
|
|
||||||
if answer.lower() in ["yes", "y", "1"]:
|
|
||||||
trust_remote_code = True
|
|
||||||
elif answer.lower() in ["no", "n", "0", ""]:
|
|
||||||
trust_remote_code = False
|
|
||||||
signal.alarm(0)
|
|
||||||
except Exception:
|
|
||||||
# OS which does not support signal.SIGALRM
|
|
||||||
raise ValueError(
|
|
||||||
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
|
||||||
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
|
||||||
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if prev_sig_handler is not None:
|
|
||||||
signal.signal(signal.SIGALRM, prev_sig_handler)
|
|
||||||
signal.alarm(0)
|
|
||||||
elif has_remote_code:
|
|
||||||
# For the CI which puts the timeout at 0
|
|
||||||
_raise_timeout_error(None, None)
|
|
||||||
|
|
||||||
if has_remote_code and not trust_remote_code:
|
if has_remote_code and not trust_remote_code:
|
||||||
raise ValueError(
|
error_msg = f"The repository for {model_name} contains custom code. "
|
||||||
f"Loading {model_name} requires you to execute the configuration file in that"
|
error_msg += (
|
||||||
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
|
"Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable."
|
||||||
" set the option `trust_remote_code=True` to remove this error."
|
if DIFFUSERS_DISABLE_REMOTE_CODE
|
||||||
|
else "Pass `trust_remote_code=True` to allow loading remote code modules."
|
||||||
|
)
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
|
elif has_remote_code and trust_remote_code:
|
||||||
|
logger.warning(
|
||||||
|
f"`trust_remote_code` is enabled. Downloading code from {model_name}. Please ensure you trust the contents of this repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
return trust_remote_code
|
return trust_remote_code
|
||||||
|
|||||||
Reference in New Issue
Block a user