mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-09 14:04:37 +08:00
Compare commits
4 Commits
sayakpaul-
...
enable-tel
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e187e239d6 | ||
|
|
59f4531a55 | ||
|
|
ff80e8a27f | ||
|
|
fe176857a2 |
@@ -21,6 +21,7 @@ import torch
|
|||||||
from huggingface_hub.utils import validate_hf_hub_args
|
from huggingface_hub.utils import validate_hf_hub_args
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from .. import __version__
|
||||||
from ..quantizers import DiffusersAutoQuantizer
|
from ..quantizers import DiffusersAutoQuantizer
|
||||||
from ..utils import deprecate, is_accelerate_available, logging
|
from ..utils import deprecate, is_accelerate_available, logging
|
||||||
from .single_file_utils import (
|
from .single_file_utils import (
|
||||||
@@ -260,6 +261,11 @@ class FromOriginalModelMixin:
|
|||||||
device = kwargs.pop("device", None)
|
device = kwargs.pop("device", None)
|
||||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||||
|
|
||||||
|
user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
|
||||||
|
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
|
||||||
|
if quantization_config is not None:
|
||||||
|
user_agent["quant"] = quantization_config.quant_method.value
|
||||||
|
|
||||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||||
torch_dtype = torch.float32
|
torch_dtype = torch.float32
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -278,6 +284,7 @@ class FromOriginalModelMixin:
|
|||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
disable_mmap=disable_mmap,
|
disable_mmap=disable_mmap,
|
||||||
|
user_agent=user_agent,
|
||||||
)
|
)
|
||||||
if quantization_config is not None:
|
if quantization_config is not None:
|
||||||
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
||||||
|
|||||||
@@ -405,13 +405,16 @@ def load_single_file_checkpoint(
|
|||||||
local_files_only=None,
|
local_files_only=None,
|
||||||
revision=None,
|
revision=None,
|
||||||
disable_mmap=False,
|
disable_mmap=False,
|
||||||
|
user_agent=None,
|
||||||
):
|
):
|
||||||
|
if user_agent is None:
|
||||||
|
user_agent = {"file_type": "single_file", "framework": "pytorch"}
|
||||||
|
|
||||||
if os.path.isfile(pretrained_model_link_or_path):
|
if os.path.isfile(pretrained_model_link_or_path):
|
||||||
pretrained_model_link_or_path = pretrained_model_link_or_path
|
pretrained_model_link_or_path = pretrained_model_link_or_path
|
||||||
|
|
||||||
else:
|
else:
|
||||||
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
|
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
|
||||||
user_agent = {"file_type": "single_file", "framework": "pytorch"}
|
|
||||||
pretrained_model_link_or_path = _get_model_file(
|
pretrained_model_link_or_path = _get_model_file(
|
||||||
repo_id,
|
repo_id,
|
||||||
weights_name=weights_name,
|
weights_name=weights_name,
|
||||||
|
|||||||
Reference in New Issue
Block a user