Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 20:44:33 +08:00)

Compare commits: 10 commits, group-memo...dduf-with-
| Author | SHA1 | Date |
|---|---|---|
|  | 1cd5155bb8 |  |
|  | b14bffeffe |  |
|  | e66c4d0dab |  |
|  | 7d2c7d5553 |  |
|  | 78135f1478 |  |
|  | d8408677c5 |  |
|  | cbee7cbc6b |  |
|  | 2eeda25321 |  |
|  | 0389333113 |  |
|  | 1fb86e34c0 |  |
@@ -347,6 +347,7 @@ class ConfigMixin:
         _ = kwargs.pop("mirror", None)
         subfolder = kwargs.pop("subfolder", None)
         user_agent = kwargs.pop("user_agent", {})
+        dduf_reader = kwargs.pop("dduf_reader", None)

         user_agent = {**user_agent, "file_type": "config"}
         user_agent = http_user_agent(user_agent)
@@ -358,8 +359,22 @@ class ConfigMixin:
                 "`self.config_name` is not defined. Note that one should not load a config from "
                 "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
             )

-        if os.path.isfile(pretrained_model_name_or_path):
+        # Custom path for now
+        if dduf_reader:
+            if subfolder is not None:
+                if dduf_reader.has_file(os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)):
+                    config_file = os.path.join(subfolder, cls.config_name)
+                else:
+                    raise ValueError(
+                        f"We did not manage to find the file {os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)} in the archive. We only have the following files {dduf_reader.files}"
+                    )
+            elif dduf_reader.has_file(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+            else:
+                raise ValueError(
+                    f"We did not manage to find the file {os.path.join(pretrained_model_name_or_path, cls.config_name)} in the archive. We only have the following files {dduf_reader.files}"
+                )
+        elif os.path.isfile(pretrained_model_name_or_path):
             config_file = pretrained_model_name_or_path
         elif os.path.isdir(pretrained_model_name_or_path):
             if subfolder is not None and os.path.isfile(
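
A practical caveat in this resolution logic: ZIP member names always use forward slashes, while `os.path.join` uses the platform separator, so `has_file` lookups built this way only line up with archive entries when the separators agree (as they do on POSIX). A minimal sketch of a separator-safe key builder — `archive_key` is an editorial illustration, not part of the diff:

```python
import os
import posixpath

def archive_key(*parts: str) -> str:
    # ZIP member names use "/" regardless of platform, so normalize
    # anything built with os.path.join before membership tests.
    return posixpath.join(*(p.replace(os.sep, "/") for p in parts if p))

print(archive_key("subfolder", "config.json"))  # subfolder/config.json
```
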
@@ -426,10 +441,8 @@ class ConfigMixin:
                 f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                 f"containing a {cls.config_name} file"
             )

         try:
             # Load config dict
-            config_dict = cls._dict_from_json_file(config_file)
+            config_dict = cls._dict_from_json_file(config_file, dduf_reader=dduf_reader)

             commit_hash = extract_commit_hash(config_file)
         except (json.JSONDecodeError, UnicodeDecodeError):
@@ -552,9 +565,12 @@ class ConfigMixin:
         return init_dict, unused_kwargs, hidden_config_dict

     @classmethod
-    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
-        with open(json_file, "r", encoding="utf-8") as reader:
-            text = reader.read()
+    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike], dduf_reader=None):
+        if dduf_reader:
+            text = dduf_reader.read_file(json_file, encoding="utf-8")
+        else:
+            with open(json_file, "r", encoding="utf-8") as reader:
+                text = reader.read()
         return json.loads(text)

     def __repr__(self):
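
The method now converges two read paths, archive bytes and on-disk text, onto the same `json.loads`. A self-contained sketch of that pattern using only the standard library (`dict_from_json` is an illustrative stand-in, not the diff's API):

```python
import io
import json
import zipfile
from typing import Optional

def dict_from_json(source: str, zf: Optional[zipfile.ZipFile] = None) -> dict:
    """Read JSON either from a member of an open ZIP or from a file on disk."""
    if zf is not None:
        text = zf.read(source).decode("utf-8")
    else:
        with open(source, "r", encoding="utf-8") as reader:
            text = reader.read()
    return json.loads(text)

buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("model_index.json", json.dumps({"_class_name": "Demo"}))
with zipfile.ZipFile(buf, "r") as zf:
    print(dict_from_json("model_index.json", zf))  # {'_class_name': 'Demo'}
```
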
@@ -128,7 +128,7 @@ def _fetch_remapped_cls_from_config(config, old_class):
         return old_class


-def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
+def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, dduf_reader=None):
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
     """
@@ -138,8 +138,15 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
         return checkpoint_file
     try:
         file_extension = os.path.basename(checkpoint_file).split(".")[-1]
+        if dduf_reader:
+            checkpoint_file = dduf_reader.read_file(checkpoint_file)
         if file_extension == SAFETENSORS_FILE_EXTENSION:
-            return safetensors.torch.load_file(checkpoint_file, device="cpu")
+            if dduf_reader:
+                # tensors are loaded on cpu
+                return safetensors.torch.load(checkpoint_file)
+            else:
+                return safetensors.torch.load_file(checkpoint_file, device="cpu")
+
         else:
             weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
             return torch.load(
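
The branch above relies on a real distinction in the safetensors API: `safetensors.torch.load_file` takes a path and can place tensors on a device, while `safetensors.torch.load` deserializes an in-memory `bytes` blob (always onto CPU) — which is exactly what reading a ZIP member returns. A self-contained sketch:

```python
import torch
from safetensors.torch import load, load_file, save, save_file

tensors = {"weight": torch.ones(2, 2)}

# Path-based round trip: load_file can target a device directly.
save_file(tensors, "weights.safetensors")
from_disk = load_file("weights.safetensors", device="cpu")

# Bytes-based round trip: what you get after reading a ZIP member.
blob = save(tensors)     # serialize to bytes
from_bytes = load(blob)  # tensors land on CPU

assert torch.equal(from_disk["weight"], from_bytes["weight"])
```
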
@@ -272,6 +279,7 @@ def _fetch_index_file(
     revision,
     user_agent,
     commit_hash,
+    dduf_reader=None,
 ):
     if is_local:
         index_file = Path(
@@ -297,6 +305,7 @@ def _fetch_index_file(
                 subfolder=None,
                 user_agent=user_agent,
                 commit_hash=commit_hash,
+                dduf_reader=dduf_reader,
             )
             index_file = Path(index_file)
         except (EntryNotFoundError, EnvironmentError):
@@ -557,6 +557,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
         quantization_config = kwargs.pop("quantization_config", None)
+        dduf_reader = kwargs.pop("dduf_reader", None)

         allow_pickle = False
         if use_safetensors is None:
@@ -649,6 +650,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             revision=revision,
             subfolder=subfolder,
             user_agent=user_agent,
+            dduf_reader=dduf_reader,
             **kwargs,
         )
         # no in-place modification of the original config.
@@ -724,6 +726,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 "revision": revision,
                 "user_agent": user_agent,
                 "commit_hash": commit_hash,
+                "dduf_reader": dduf_reader,
             }
             index_file = _fetch_index_file(**index_file_kwargs)
             # In case the index file was not found we still have to consider the legacy format.
@@ -759,7 +762,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):

             model = load_flax_checkpoint_in_pytorch_model(model, model_file)
         else:
-            if is_sharded:
+            # if it is sharded, we already have the index
+            if is_sharded and not dduf_reader:
                 sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
                     pretrained_model_name_or_path,
                     index_file,
@@ -790,6 +794,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     subfolder=subfolder,
                     user_agent=user_agent,
                     commit_hash=commit_hash,
+                    dduf_reader=dduf_reader,
                 )

             except IOError as e:
@@ -813,6 +818,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     subfolder=subfolder,
                     user_agent=user_agent,
                     commit_hash=commit_hash,
+                    dduf_reader=dduf_reader,
                 )

         if low_cpu_mem_usage:
@@ -837,7 +843,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 # TODO (sayakpaul, SunMarc): remove this after model loading refactor
                 elif is_quant_method_bnb:
                     param_device = torch.cuda.current_device()
-                state_dict = load_state_dict(model_file, variant=variant)
+                state_dict = load_state_dict(model_file, variant=variant, dduf_reader=dduf_reader)
                 model._convert_deprecated_attention_blocks(state_dict)

                 # move the params from meta device to cpu
@@ -937,7 +943,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             else:
                 model = cls.from_config(config, **unused_kwargs)

-                state_dict = load_state_dict(model_file, variant=variant)
+                state_dict = load_state_dict(model_file, variant=variant, dduf_reader=dduf_reader)
                 model._convert_deprecated_attention_blocks(state_dict)

                 model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
@@ -627,6 +627,7 @@ def load_sub_model(
     low_cpu_mem_usage: bool,
     cached_folder: Union[str, os.PathLike],
     use_safetensors: bool,
+    dduf_reader,
 ):
     """Helper method to load the module `name` from `library_name` and `class_name`"""
@@ -721,7 +722,10 @@ def load_sub_model(
         loading_kwargs["low_cpu_mem_usage"] = False

     # check if the module is in a subdirectory
-    if os.path.isdir(os.path.join(cached_folder, name)):
+    if dduf_reader:
+        loading_kwargs["dduf_reader"] = dduf_reader
+        loaded_sub_model = load_method(name, **loading_kwargs)
+    elif os.path.isdir(os.path.join(cached_folder, name)):
         loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
     else:
         # else load from the root directory
@@ -50,6 +50,7 @@ from ..utils import (
     CONFIG_NAME,
     DEPRECATED_REVISION_ARGS,
     BaseOutput,
+    DDUFReader,
     PushToHubMixin,
     is_accelerate_available,
     is_accelerate_version,
@@ -193,6 +194,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         variant: Optional[str] = None,
         max_shard_size: Optional[Union[int, str]] = None,
         push_to_hub: bool = False,
+        dduf_filename: Optional[Union[str, os.PathLike]] = None,
         **kwargs,
     ):
         """
@@ -301,9 +303,56 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

             save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)

+            if dduf_filename:
+                import shutil
+                import zipfile
+
+                def zipdir(dir_to_archive, zipf):
+                    """Archive a directory"""
+                    for root, dirs, files in os.walk(dir_to_archive):
+                        for file in files:
+                            file_path = os.path.join(root, file)
+                            arcname = os.path.join(
+                                os.path.basename(dir_to_archive), os.path.relpath(file_path, start=dir_to_archive)
+                            )
+                            zipf.write(file_path, arcname=arcname)
+
+                dduf_file_path = os.path.join(save_directory, dduf_filename)
+
+                if os.path.isdir(dduf_file_path):
+                    logger.warning(
+                        f"Removing the existing folder {dduf_file_path} so that we can save the DDUF archive."
+                    )
+                    shutil.rmtree(dduf_file_path)
+                if (
+                    os.path.exists(dduf_file_path)
+                    and os.path.isfile(dduf_file_path)
+                    and zipfile.is_zipfile(dduf_file_path)
+                ):
+                    # Open in append mode if the file exists
+                    mode = "a"
+                else:
+                    # Open in write mode to create it if it doesn't exist
+                    mode = "w"
+                with zipfile.ZipFile(dduf_file_path, mode=mode, compression=zipfile.ZIP_STORED) as zipf:
+                    dir_to_archive = os.path.join(save_directory, pipeline_component_name)
+                    if os.path.isdir(dir_to_archive):
+                        zipdir(dir_to_archive, zipf)
+                        shutil.rmtree(dir_to_archive)
+
         # finally save the config
         self.save_config(save_directory)

+        # Takes care of including the "model_index.json" inside the ZIP.
+        # TODO: Include a DDUF metadata file.
+        if dduf_filename:
+            import zipfile
+
+            with zipfile.ZipFile(dduf_file_path, mode="a", compression=zipfile.ZIP_STORED) as zipf:
+                config_path = os.path.join(save_directory, self.config_name)
+                zipf.write(config_path, arcname=os.path.basename(config_path))
+                os.remove(config_path)
+
         if push_to_hub:
             # Create a new empty model card and eventually tag it
             model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
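
A point worth calling out: the archive is opened with `zipfile.ZIP_STORED`, i.e. no compression, so weight files keep their exact byte layout inside the ZIP and can be read back without a decompression pass. A minimal standalone sketch of the same archive-a-folder pattern (paths and folder names are illustrative):

```python
import os
import zipfile

def zipdir(dir_to_archive: str, zipf: zipfile.ZipFile) -> None:
    """Add every file under dir_to_archive to an open ZipFile,
    keeping the folder name as the top-level prefix."""
    for root, _dirs, files in os.walk(dir_to_archive):
        for name in files:
            file_path = os.path.join(root, name)
            arcname = os.path.join(
                os.path.basename(dir_to_archive),
                os.path.relpath(file_path, start=dir_to_archive),
            )
            zipf.write(file_path, arcname=arcname)

os.makedirs("unet_demo", exist_ok=True)
with open(os.path.join("unet_demo", "config.json"), "w") as f:
    f.write("{}")

# ZIP_STORED keeps members uncompressed, byte-for-byte.
with zipfile.ZipFile("pipe.dduf", mode="w", compression=zipfile.ZIP_STORED) as zipf:
    zipdir("unet_demo", zipf)

print(zipfile.ZipFile("pipe.dduf").namelist())  # ['unet_demo/config.json']
```
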
@@ -523,6 +572,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
                   saved using [`~DiffusionPipeline.save_pretrained`].
+                - A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf archive or
+                  folder
             torch_dtype (`str` or `torch.dtype`, *optional*):
                 Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
                 dtype is automatically derived from the model's weights.
@@ -617,6 +668,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             variant (`str`, *optional*):
                 Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                 loading `from_flax`.
+            dduf (`str`, *optional*):
+                Load weights from the specified dduf archive or folder.

         <Tip>
@@ -666,6 +719,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         offload_state_dict = kwargs.pop("offload_state_dict", False)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
+        dduf = kwargs.pop("dduf", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
@@ -736,6 +790,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 custom_pipeline=custom_pipeline,
                 custom_revision=custom_revision,
                 variant=variant,
+                dduf=dduf,
                 load_connected_pipeline=load_connected_pipeline,
                 **kwargs,
             )
@@ -757,7 +812,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 )
                 logger.warning(warn_msg)

-        config_dict = cls.load_config(cached_folder)
+        dduf_reader = None
+        if dduf:
+            zip_file_path = os.path.join(cached_folder, dduf)
+            dduf_reader = DDUFReader(zip_file_path)
+            # The reader already contains all the files needed, no need to check again
+            cached_folder = ""
+
+        config_dict = cls.load_config(cached_folder, dduf_reader=dduf_reader)

         # pop out "_ignore_files" as it is only needed for download
         config_dict.pop("_ignore_files", None)
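
Taken together with the `save_pretrained` changes above, the diff wires up a round trip along these lines — hedged: the model id is illustrative, and `dduf_filename`/`dduf` are the arguments this diff introduces, so the exact call shape may still change in review:

```python
from diffusers import DiffusionPipeline

# Save a pipeline into a single uncompressed DDUF archive.
pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
pipe.save_pretrained("./sd15", dduf_filename="sd15.dduf")

# Load straight from the archive: from_pretrained builds a DDUFReader
# internally and threads it down to every sub-model loader.
pipe = DiffusionPipeline.from_pretrained("./sd15", dduf="sd15.dduf")
```
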
@@ -914,6 +976,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 low_cpu_mem_usage=low_cpu_mem_usage,
                 cached_folder=cached_folder,
                 use_safetensors=use_safetensors,
+                dduf_reader=dduf_reader,
             )
             logger.info(
                 f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -1227,6 +1290,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             variant (`str`, *optional*):
                 Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                 loading `from_flax`.
+            dduf (`str`, *optional*):
+                Load weights from the specified DDUF archive or folder.
             use_safetensors (`bool`, *optional*, defaults to `None`):
                 If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                 safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
@@ -1267,6 +1332,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
         trust_remote_code = kwargs.pop("trust_remote_code", False)
+        dduf = kwargs.pop("dduf", None)

         allow_pickle = False
         if use_safetensors is None:
@@ -1285,7 +1351,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             local_files_only = True
             model_info_call_error = e  # save error to reraise it if model is not cached locally

-        if not local_files_only:
+        if dduf is not None and not local_files_only:
+            dduf_available = False
+            for sibling in info.siblings:
+                dduf_available = dduf in sibling.rfilename
+            if not dduf_available:
+                raise ValueError(f"Requested {dduf} file is not available in {pretrained_model_name}.")
+
+        if not local_files_only and not dduf:
             filenames = {sibling.rfilename for sibling in info.siblings}
             if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
                 warn_msg = (
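
Note that the loop as written keeps only the result for the *last* sibling, so the check effectively tests a single filename; the usual idiom scans all of them. A hedged sketch of that availability check against the Hub API (the repo id and archive name are illustrative):

```python
from huggingface_hub import HfApi

def dduf_is_available(repo_id: str, dduf: str) -> bool:
    """True if any file in the repo matches the requested DDUF name."""
    info = HfApi().model_info(repo_id)
    return any(dduf in sibling.rfilename for sibling in info.siblings)

print(dduf_is_available("stable-diffusion-v1-5/stable-diffusion-v1-5", "sd15.dduf"))
```
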
@@ -1346,7 +1419,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
             # also allow downloading config.json files with the model
             allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
-
             allow_patterns += [
                 SCHEDULER_CONFIG_NAME,
                 CONFIG_NAME,
@@ -1425,10 +1497,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             return snapshot_folder

         user_agent = {"pipeline_class": cls.__name__}
-        if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
+        if not dduf and custom_pipeline is not None and not custom_pipeline.endswith(".py"):
             user_agent["custom_pipeline"] = custom_pipeline

         # download all allow_patterns - ignore_patterns
+        # also allow downloading the dduf
+        if dduf is not None:
+            allow_patterns = [dduf]
+            ignore_patterns = []
         try:
             cached_folder = snapshot_download(
                 pretrained_model_name,
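
When a DDUF archive is requested, the download shrinks to a single allow pattern: only the archive is fetched instead of the per-component folders. The underlying `huggingface_hub.snapshot_download` filter works like this (repo id and file name are illustrative):

```python
from huggingface_hub import snapshot_download

# Fetch only the requested archive; everything else in the repo is skipped.
cached_folder = snapshot_download(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    allow_patterns=["sd15.dduf"],
    ignore_patterns=[],
)
print(cached_folder)  # local snapshot directory containing just sd15.dduf
```
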
@@ -1443,26 +1519,27 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             )

         # retrieve pipeline class from local file
-        cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
-        cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
+        if not dduf:
+            cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
+            cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name

-        diffusers_module = importlib.import_module(__name__.split(".")[0])
-        pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None
+            diffusers_module = importlib.import_module(__name__.split(".")[0])
+            pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None

-        if pipeline_class is not None and pipeline_class._load_connected_pipes:
-            modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
-            connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], [])
-            for connected_pipe_repo_id in connected_pipes:
-                download_kwargs = {
-                    "cache_dir": cache_dir,
-                    "force_download": force_download,
-                    "proxies": proxies,
-                    "local_files_only": local_files_only,
-                    "token": token,
-                    "variant": variant,
-                    "use_safetensors": use_safetensors,
-                }
-                DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs)
+            if pipeline_class is not None and pipeline_class._load_connected_pipes:
+                modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
+                connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], [])
+                for connected_pipe_repo_id in connected_pipes:
+                    download_kwargs = {
+                        "cache_dir": cache_dir,
+                        "force_download": force_download,
+                        "proxies": proxies,
+                        "local_files_only": local_files_only,
+                        "token": token,
+                        "variant": variant,
+                        "use_safetensors": use_safetensors,
+                    }
+                    DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs)

         return cached_folder
@@ -34,6 +34,7 @@ from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
 from .modeling_stable_audio import StableAudioProjectionModel


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
@@ -732,7 +733,7 @@ class StableAudioPipeline(DiffusionPipeline):
                 if callback is not None and i % callback_steps == 0:
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+            if XLA_AVAILABLE:
+                xm.mark_step()
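
For context, `xm.mark_step()` cuts the lazily traced XLA graph so each denoising step compiles and executes as its own unit, and guarding on availability keeps the pipeline importable without `torch_xla`. A minimal sketch of the same guard pattern (the loop body is a stand-in for the real denoising loop):

```python
try:
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False

for i in range(4):  # stand-in for the denoising loop
    ...  # scheduler step, callbacks, etc.
    if XLA_AVAILABLE:
        xm.mark_step()  # flush the lazy graph once per step
```
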
@@ -35,6 +35,7 @@ from .constants import (
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
 )
+from .dduf import DDUFReader
 from .deprecation_utils import deprecate
 from .doc_utils import replace_example_docstring
 from .dynamic_modules_utils import get_class_from_dynamic_module

src/diffusers/utils/dduf.py (new file, +41 lines)
@@ -0,0 +1,41 @@
+import zipfile
+
+
+class DDUFReader:
+    def __init__(self, dduf_file):
+        self.dduf_file = dduf_file
+        self.files = []
+        self.post_init()
+
+    def post_init(self):
+        """
+        Check that the DDUF file is valid
+        """
+        if not zipfile.is_zipfile(self.dduf_file):
+            raise ValueError(f"The file '{self.dduf_file}' is not a valid ZIP archive.")
+
+        try:
+            with zipfile.ZipFile(self.dduf_file, "r") as zf:
+                # Check integrity and store file list
+                zf.testzip()  # Returns None if no corrupt files are found
+                self.files = zf.namelist()
+        except zipfile.BadZipFile:
+            raise ValueError(f"The file '{self.dduf_file}' is not a valid ZIP archive.")
+        except Exception as e:
+            raise RuntimeError(f"An error occurred while validating the ZIP file: {e}")
+
+    def has_file(self, file):
+        return file in self.files
+
+    def read_file(self, file_name, encoding=None):
+        """
+        Reads the content of a specific file in the ZIP archive without extracting.
+        """
+        if file_name not in self.files:
+            raise ValueError(f"{file_name} is not in the list of files {self.files}")
+        with zipfile.ZipFile(self.dduf_file, "r") as zf:
+            with zf.open(file_name) as file:
+                file = file.read()
+        if encoding is not None:
+            file = file.decode(encoding)
+        return file
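
To see the reader's whole surface in one place — eager validation on construction, `files`/`has_file` for membership, `read_file` for raw bytes or decoded text — here is a self-contained usage sketch that builds a toy archive first:

```python
import json
import zipfile

from diffusers.utils.dduf import DDUFReader  # module path added by this diff

# Build a toy archive with one config member.
with zipfile.ZipFile("toy.dduf", "w", compression=zipfile.ZIP_STORED) as zf:
    zf.writestr("unet/config.json", json.dumps({"sample_size": 64}))

reader = DDUFReader("toy.dduf")             # validates the ZIP eagerly
print(reader.files)                         # ['unet/config.json']
print(reader.has_file("unet/config.json"))  # True
config = json.loads(reader.read_file("unet/config.json", encoding="utf-8"))
print(config["sample_size"])                # 64
```
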
@@ -291,9 +291,20 @@ def _get_model_file(
     user_agent: Optional[Union[Dict, str]] = None,
     revision: Optional[str] = None,
     commit_hash: Optional[str] = None,
+    dduf_reader=None,
 ):
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-    if os.path.isfile(pretrained_model_name_or_path):
+
+    if dduf_reader:
+        if dduf_reader.has_file(os.path.join(pretrained_model_name_or_path, weights_name)):
+            return os.path.join(pretrained_model_name_or_path, weights_name)
+        elif subfolder is not None and os.path.isfile(
+            os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+        ):
+            return os.path.join(pretrained_model_name_or_path, weights_name)
+        else:
+            raise EnvironmentError(f"Error no file named {weights_name} found in archive {dduf_reader.files}.")
+    elif os.path.isfile(pretrained_model_name_or_path):
         return pretrained_model_name_or_path
     elif os.path.isdir(pretrained_model_name_or_path):
         if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):