mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-14 16:34:27 +08:00

Compare commits: 10 commits (fast-gpu-t... → dduf-with-...)

Commit SHA1s:
1cd5155bb8
b14bffeffe
e66c4d0dab
7d2c7d5553
78135f1478
d8408677c5
cbee7cbc6b
2eeda25321
0389333113
1fb86e34c0
@@ -347,6 +347,7 @@ class ConfigMixin:
         _ = kwargs.pop("mirror", None)
         subfolder = kwargs.pop("subfolder", None)
         user_agent = kwargs.pop("user_agent", {})
+        dduf_reader = kwargs.pop("dduf_reader", None)
 
         user_agent = {**user_agent, "file_type": "config"}
         user_agent = http_user_agent(user_agent)
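With this kwarg in place, a config can be resolved against an open DDUF archive instead of a local folder. A hedged sketch of the resulting call, per this patch (the archive path and the `"unet"` prefix are illustrative):

```python
from diffusers import UNet2DConditionModel
from diffusers.utils import DDUFReader  # exported by this patch

reader = DDUFReader("pipeline.dduf")  # hypothetical archive path
# "unet" is resolved as a member prefix inside the ZIP, not as a directory on disk.
config = UNet2DConditionModel.load_config("unet", dduf_reader=reader)
```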
@@ -358,8 +359,22 @@ class ConfigMixin:
                 "`self.config_name` is not defined. Note that one should not load a config from "
                 "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
             )
 
-        if os.path.isfile(pretrained_model_name_or_path):
+        # Custom path for now
+        if dduf_reader:
+            if subfolder is not None:
+                if dduf_reader.has_file(os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)):
+                    config_file = os.path.join(subfolder, cls.config_name)
+                else:
+                    raise ValueError(
+                        f"We did not manage to find the file {os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)} in the archive. We only have the following files {dduf_reader.files}"
+                    )
+            elif dduf_reader.has_file(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+            else:
+                raise ValueError(
+                    f"We did not manage to find the file {os.path.join(pretrained_model_name_or_path, cls.config_name)} in the archive. We only have the following files {dduf_reader.files}"
+                )
+        elif os.path.isfile(pretrained_model_name_or_path):
             config_file = pretrained_model_name_or_path
         elif os.path.isdir(pretrained_model_name_or_path):
             if subfolder is not None and os.path.isfile(
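The branch above reduces config resolution to ZIP-member name matching: `has_file` is a membership test against `namelist()`-style entries, so `os.path.join` here builds archive keys rather than filesystem paths. A minimal standalone sketch (the member names are illustrative):

```python
import os

# Stand-in for dduf_reader.files on a typical pipeline archive.
files = ["model_index.json", "unet/config.json", "vae/config.json"]

def has_file(name):
    return name in files

root, subfolder, config_name = "", "unet", "config.json"
candidate = os.path.join(root, subfolder, config_name)  # "unet/config.json"
assert has_file(candidate)
```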
@@ -426,10 +441,8 @@ class ConfigMixin:
                 f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                 f"containing a {cls.config_name} file"
             )
 
         try:
-            # Load config dict
-            config_dict = cls._dict_from_json_file(config_file)
-
+            config_dict = cls._dict_from_json_file(config_file, dduf_reader=dduf_reader)
             commit_hash = extract_commit_hash(config_file)
         except (json.JSONDecodeError, UnicodeDecodeError):
@@ -552,9 +565,12 @@ class ConfigMixin:
         return init_dict, unused_kwargs, hidden_config_dict
 
     @classmethod
-    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
-        with open(json_file, "r", encoding="utf-8") as reader:
-            text = reader.read()
+    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike], dduf_reader=None):
+        if dduf_reader:
+            text = dduf_reader.read_file(json_file, encoding="utf-8")
+        else:
+            with open(json_file, "r", encoding="utf-8") as reader:
+                text = reader.read()
         return json.loads(text)
 
     def __repr__(self):
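`read_file(..., encoding="utf-8")` returns the decoded text of the ZIP member, so the JSON path needs no temporary extraction. An equivalent stdlib-only sketch (archive and member names are illustrative):

```python
import json
import zipfile

with zipfile.ZipFile("pipeline.dduf") as zf:  # hypothetical archive
    text = zf.read("unet/config.json").decode("utf-8")  # hypothetical member
config = json.loads(text)
print(config.get("_class_name"))
```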
@@ -128,7 +128,7 @@ def _fetch_remapped_cls_from_config(config, old_class):
     return old_class
 
 
-def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
+def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, dduf_reader=None):
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
     """
@@ -138,8 +138,15 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
         return checkpoint_file
     try:
         file_extension = os.path.basename(checkpoint_file).split(".")[-1]
+        if dduf_reader:
+            checkpoint_file = dduf_reader.read_file(checkpoint_file)
         if file_extension == SAFETENSORS_FILE_EXTENSION:
-            return safetensors.torch.load_file(checkpoint_file, device="cpu")
+            if dduf_reader:
+                # tensors are loaded on cpu
+                return safetensors.torch.load(checkpoint_file)
+            else:
+                return safetensors.torch.load_file(checkpoint_file, device="cpu")
+
         else:
             weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
             return torch.load(
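The split matters because `safetensors.torch.load_file` expects a path it can memory-map, while `safetensors.torch.load` deserializes an in-memory bytes object; after `read_file` the checkpoint only exists as bytes, so the in-memory variant applies and tensors come back on CPU. A sketch under that assumption (the path is illustrative):

```python
import safetensors.torch
import torch

path = "unet/diffusion_pytorch_model.safetensors"  # illustrative

# Filesystem route: memory-mapped load straight onto CPU.
# state_dict = safetensors.torch.load_file(path, device="cpu")

# Archive route: bytes already read out of the ZIP (stand-in for dduf_reader.read_file).
with open(path, "rb") as f:
    data = f.read()
state_dict = safetensors.torch.load(data)
assert all(t.device == torch.device("cpu") for t in state_dict.values())
```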
@@ -272,6 +279,7 @@ def _fetch_index_file(
     revision,
     user_agent,
     commit_hash,
+    dduf_reader=None,
 ):
     if is_local:
         index_file = Path(
@@ -297,6 +305,7 @@ def _fetch_index_file(
                 subfolder=None,
                 user_agent=user_agent,
                 commit_hash=commit_hash,
+                dduf_reader=dduf_reader,
             )
             index_file = Path(index_file)
         except (EntryNotFoundError, EnvironmentError):
@@ -557,6 +557,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
         quantization_config = kwargs.pop("quantization_config", None)
+        dduf_reader = kwargs.pop("dduf_reader", None)
 
         allow_pickle = False
         if use_safetensors is None:
@@ -649,6 +650,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             revision=revision,
             subfolder=subfolder,
             user_agent=user_agent,
+            dduf_reader=dduf_reader,
             **kwargs,
         )
         # no in-place modification of the original config.
@@ -724,6 +726,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 "revision": revision,
                 "user_agent": user_agent,
                 "commit_hash": commit_hash,
+                "dduf_reader": dduf_reader,
             }
             index_file = _fetch_index_file(**index_file_kwargs)
             # In case the index file was not found we still have to consider the legacy format.
@@ -759,7 +762,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
                 model = load_flax_checkpoint_in_pytorch_model(model, model_file)
         else:
-            if is_sharded:
+            # in the case it is sharded, we already have the index
+            if is_sharded and not dduf_reader:
                 sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
                     pretrained_model_name_or_path,
                     index_file,
@@ -790,6 +794,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     subfolder=subfolder,
                     user_agent=user_agent,
                     commit_hash=commit_hash,
+                    dduf_reader=dduf_reader,
                 )
 
             except IOError as e:
@@ -813,6 +818,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     subfolder=subfolder,
                     user_agent=user_agent,
                     commit_hash=commit_hash,
+                    dduf_reader=dduf_reader,
                 )
 
         if low_cpu_mem_usage:
@@ -837,7 +843,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 # TODO (sayakpaul, SunMarc): remove this after model loading refactor
                 elif is_quant_method_bnb:
                     param_device = torch.cuda.current_device()
-                state_dict = load_state_dict(model_file, variant=variant)
+                state_dict = load_state_dict(model_file, variant=variant, dduf_reader=dduf_reader)
                 model._convert_deprecated_attention_blocks(state_dict)
 
                 # move the params from meta device to cpu
@@ -937,7 +943,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             else:
                 model = cls.from_config(config, **unused_kwargs)
 
-                state_dict = load_state_dict(model_file, variant=variant)
+                state_dict = load_state_dict(model_file, variant=variant, dduf_reader=dduf_reader)
                 model._convert_deprecated_attention_blocks(state_dict)
 
                 model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
@@ -627,6 +627,7 @@ def load_sub_model(
     low_cpu_mem_usage: bool,
     cached_folder: Union[str, os.PathLike],
     use_safetensors: bool,
+    dduf_reader,
 ):
     """Helper method to load the module `name` from `library_name` and `class_name`"""
 
@@ -721,7 +722,10 @@ def load_sub_model(
         loading_kwargs["low_cpu_mem_usage"] = False
 
     # check if the module is in a subdirectory
-    if os.path.isdir(os.path.join(cached_folder, name)):
+    if dduf_reader:
+        loading_kwargs["dduf_reader"] = dduf_reader
+        loaded_sub_model = load_method(name, **loading_kwargs)
+    elif os.path.isdir(os.path.join(cached_folder, name)):
         loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
     else:
         # else load from the root directory
@@ -50,6 +50,7 @@ from ..utils import (
     CONFIG_NAME,
     DEPRECATED_REVISION_ARGS,
     BaseOutput,
+    DDUFReader,
     PushToHubMixin,
     is_accelerate_available,
     is_accelerate_version,
@@ -193,6 +194,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         variant: Optional[str] = None,
         max_shard_size: Optional[Union[int, str]] = None,
         push_to_hub: bool = False,
+        dduf_filename: Optional[Union[str, os.PathLike]] = None,
         **kwargs,
     ):
         """
@@ -301,9 +303,56 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
 
             save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
 
+            if dduf_filename:
+                import shutil
+                import zipfile
+
+                def zipdir(dir_to_archive, zipf):
+                    """Archive a directory"""
+                    for root, dirs, files in os.walk(dir_to_archive):
+                        for file in files:
+                            file_path = os.path.join(root, file)
+                            arcname = os.path.join(
+                                os.path.basename(dir_to_archive), os.path.relpath(file_path, start=dir_to_archive)
+                            )
+                            zipf.write(file_path, arcname=arcname)
+
+                dduf_file_path = os.path.join(save_directory, dduf_filename)
+
+                if os.path.isdir(dduf_file_path):
+                    logger.warning(
+                        f"Removing the existing folder {dduf_file_path} so that we can save the DDUF archive."
+                    )
+                    shutil.rmtree(dduf_file_path)
+                if (
+                    os.path.exists(dduf_file_path)
+                    and os.path.isfile(dduf_file_path)
+                    and zipfile.is_zipfile(dduf_file_path)
+                ):
+                    # Open in append mode if the file exists
+                    mode = "a"
+                else:
+                    # Open in write mode to create it if it doesn't exist
+                    mode = "w"
+                with zipfile.ZipFile(dduf_file_path, mode=mode, compression=zipfile.ZIP_STORED) as zipf:
+                    dir_to_archive = os.path.join(save_directory, pipeline_component_name)
+                    if os.path.isdir(dir_to_archive):
+                        zipdir(dir_to_archive, zipf)
+                        shutil.rmtree(dir_to_archive)
+
         # finally save the config
         self.save_config(save_directory)
 
+        # Takes care of including the "model_index.json" inside the ZIP.
+        # TODO: Include a DDUF metadata file.
+        if dduf_filename:
+            import zipfile
+
+            with zipfile.ZipFile(dduf_file_path, mode="a", compression=zipfile.ZIP_STORED) as zipf:
+                config_path = os.path.join(save_directory, self.config_name)
+                zipf.write(config_path, arcname=os.path.basename(config_path))
+                os.remove(config_path)
+
         if push_to_hub:
             # Create a new empty model card and eventually tag it
             model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
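Note the archive is written with `zipfile.ZIP_STORED`, i.e. uncompressed, which keeps each member's bytes contiguous and cheap to read back. A hedged usage sketch of the new argument (repo id, output directory, and archive name are illustrative):

```python
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
pipe.save_pretrained("./sd15-dduf", dduf_filename="pipeline.dduf")

# Per this patch, ./sd15-dduf/pipeline.dduf should now contain members such as
#   model_index.json
#   unet/config.json
#   unet/diffusion_pytorch_model.safetensors
# and each per-component folder is removed after being zipped.
```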
@@ -523,6 +572,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                     - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
                       saved using
                       [`~DiffusionPipeline.save_pretrained`].
+                    - A path to a *directory* (for example `./my_pipeline_directory/`) containing a DDUF archive or
+                      folder
             torch_dtype (`str` or `torch.dtype`, *optional*):
                 Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
                 dtype is automatically derived from the model's weights.
@@ -617,6 +668,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             variant (`str`, *optional*):
                 Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                 loading `from_flax`.
+            dduf (`str`, *optional*):
+                Load weights from the specified DDUF archive or folder.
 
         <Tip>
 
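A hedged load-side sketch of the documented `dduf` argument, per this patch (repo id and archive name are illustrative):

```python
import torch
from diffusers import DiffusionPipeline

# The archive is downloaded on its own and every component is then read out of
# the ZIP rather than from a snapshot folder.
pipe = DiffusionPipeline.from_pretrained(
    "user/repo-with-dduf",  # hypothetical repo id
    dduf="pipeline.dduf",   # hypothetical archive name
    torch_dtype=torch.float16,
)
```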
@@ -666,6 +719,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         offload_state_dict = kwargs.pop("offload_state_dict", False)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
+        dduf = kwargs.pop("dduf", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
@@ -736,6 +790,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 custom_pipeline=custom_pipeline,
                 custom_revision=custom_revision,
                 variant=variant,
+                dduf=dduf,
                 load_connected_pipeline=load_connected_pipeline,
                 **kwargs,
             )
@@ -757,7 +812,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             )
             logger.warning(warn_msg)
 
-        config_dict = cls.load_config(cached_folder)
+        dduf_reader = None
+        if dduf:
+            zip_file_path = os.path.join(cached_folder, dduf)
+            dduf_reader = DDUFReader(zip_file_path)
+            # The reader already contains all the files needed, no need to check again
+            cached_folder = ""
+
+        config_dict = cls.load_config(cached_folder, dduf_reader=dduf_reader)
 
         # pop out "_ignore_files" as it is only needed for download
         config_dict.pop("_ignore_files", None)
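Resetting `cached_folder` to the empty string is what makes the downstream `os.path.join` calls produce archive-relative member names, which is the key form `has_file`/`read_file` expect. A two-line illustration (POSIX separators assumed):

```python
import os

cached_folder = ""  # as set above when a DDUF reader is active
print(os.path.join(cached_folder, "unet", "config.json"))  # -> "unet/config.json"
```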
@@ -914,6 +976,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                     low_cpu_mem_usage=low_cpu_mem_usage,
                     cached_folder=cached_folder,
                     use_safetensors=use_safetensors,
+                    dduf_reader=dduf_reader,
                 )
                 logger.info(
                     f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -1227,6 +1290,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             variant (`str`, *optional*):
                 Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                 loading `from_flax`.
+            dduf (`str`, *optional*):
+                Load weights from the specified DDUF archive or folder.
             use_safetensors (`bool`, *optional*, defaults to `None`):
                 If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                 safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
@@ -1267,6 +1332,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
         trust_remote_code = kwargs.pop("trust_remote_code", False)
+        dduf = kwargs.pop("dduf", None)
 
         allow_pickle = False
         if use_safetensors is None:
@@ -1285,7 +1351,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 local_files_only = True
                 model_info_call_error = e  # save error to reraise it if model is not cached locally
 
-        if not local_files_only:
+        if dduf is not None and not local_files_only:
+            dduf_available = False
+            for sibling in info.siblings:
+                dduf_available = dduf in sibling.rfilename
+            if not dduf_available:
+                raise ValueError(f"Requested {dduf} file is not available in {pretrained_model_name}.")
+
+        if not local_files_only and not dduf:
             filenames = {sibling.rfilename for sibling in info.siblings}
             if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
                 warn_msg = (
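The availability probe walks `info.siblings` from the Hub metadata; its intent can be stated compactly with `any()` over the repo file listing. A hedged standalone equivalent (repo id and file name are illustrative):

```python
from huggingface_hub import model_info

info = model_info("user/repo-with-dduf")  # hypothetical repo id
if not any("pipeline.dduf" in s.rfilename for s in info.siblings):
    raise ValueError("Requested pipeline.dduf file is not available in user/repo-with-dduf.")
```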
@@ -1346,7 +1419,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
             # also allow downloading config.json files with the model
             allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
-
             allow_patterns += [
                 SCHEDULER_CONFIG_NAME,
                 CONFIG_NAME,
@@ -1425,10 +1497,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             return snapshot_folder
 
         user_agent = {"pipeline_class": cls.__name__}
-        if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
+        if not dduf and custom_pipeline is not None and not custom_pipeline.endswith(".py"):
             user_agent["custom_pipeline"] = custom_pipeline
 
         # download all allow_patterns - ignore_patterns
+        # also allow downloading the dduf
+        if dduf is not None:
+            allow_patterns = [dduf]
+            ignore_patterns = []
         try:
             cached_folder = snapshot_download(
                 pretrained_model_name,
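With `allow_patterns` collapsed to the single archive name and `ignore_patterns` cleared, the snapshot download fetches exactly one file. A hedged standalone equivalent (repo id and archive name are illustrative):

```python
from huggingface_hub import snapshot_download

folder = snapshot_download(
    "user/repo-with-dduf",             # hypothetical repo id
    allow_patterns=["pipeline.dduf"],  # only the archive is fetched
)
```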
@@ -1443,26 +1519,27 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             )
 
             # retrieve pipeline class from local file
-            cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
-            cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
+            if not dduf:
+                cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
+                cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
 
                 diffusers_module = importlib.import_module(__name__.split(".")[0])
                 pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None
 
                 if pipeline_class is not None and pipeline_class._load_connected_pipes:
                     modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
                     connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], [])
                     for connected_pipe_repo_id in connected_pipes:
                         download_kwargs = {
                             "cache_dir": cache_dir,
                             "force_download": force_download,
                             "proxies": proxies,
                             "local_files_only": local_files_only,
                             "token": token,
                             "variant": variant,
                             "use_safetensors": use_safetensors,
                         }
                         DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs)
 
             return cached_folder
@@ -34,6 +34,7 @@ from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
 from .modeling_stable_audio import StableAudioProjectionModel
 
+
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
 
@@ -732,7 +733,7 @@ class StableAudioPipeline(DiffusionPipeline):
             if callback is not None and i % callback_steps == 0:
                 step_idx = i // getattr(self.scheduler, "order", 1)
                 callback(step_idx, t, latents)
 
             if XLA_AVAILABLE:
                 xm.mark_step()
 
@@ -35,6 +35,7 @@ from .constants import (
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
 )
+from .dduf import DDUFReader
 from .deprecation_utils import deprecate
 from .doc_utils import replace_example_docstring
 from .dynamic_modules_utils import get_class_from_dynamic_module
src/diffusers/utils/dduf.py (new file, 41 lines)
@@ -0,0 +1,41 @@
+import zipfile
+
+
+class DDUFReader:
+    def __init__(self, dduf_file):
+        self.dduf_file = dduf_file
+        self.files = []
+        self.post_init()
+
+    def post_init(self):
+        """
+        Check that the DDUF file is valid
+        """
+        if not zipfile.is_zipfile(self.dduf_file):
+            raise ValueError(f"The file '{self.dduf_file}' is not a valid ZIP archive.")
+
+        try:
+            with zipfile.ZipFile(self.dduf_file, "r") as zf:
+                # Check integrity and store file list
+                zf.testzip()  # Returns None if no corrupt files are found
+                self.files = zf.namelist()
+        except zipfile.BadZipFile:
+            raise ValueError(f"The file '{self.dduf_file}' is not a valid ZIP archive.")
+        except Exception as e:
+            raise RuntimeError(f"An error occurred while validating the ZIP file: {e}")
+
+    def has_file(self, file):
+        return file in self.files
+
+    def read_file(self, file_name, encoding=None):
+        """
+        Reads the content of a specific file in the ZIP archive without extracting.
+        """
+        if file_name not in self.files:
+            raise ValueError(f"{file_name} is not in the list of files {self.files}")
+        with zipfile.ZipFile(self.dduf_file, "r") as zf:
+            with zf.open(file_name) as file:
+                file = file.read()
+                if encoding is not None:
+                    file = file.decode(encoding)
+        return file
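A short usage sketch of the new reader as defined above (archive and member names are illustrative):

```python
from diffusers.utils import DDUFReader

reader = DDUFReader("pipeline.dduf")  # validates the ZIP via post_init()
print(reader.files)                   # every member name in the archive
if reader.has_file("model_index.json"):
    index_text = reader.read_file("model_index.json", encoding="utf-8")  # str
weights_bytes = reader.read_file("unet/diffusion_pytorch_model.safetensors")  # bytes
```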
@@ -291,9 +291,20 @@ def _get_model_file(
     user_agent: Optional[Union[Dict, str]] = None,
     revision: Optional[str] = None,
     commit_hash: Optional[str] = None,
+    dduf_reader=None,
 ):
     pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-    if os.path.isfile(pretrained_model_name_or_path):
+
+    if dduf_reader:
+        if dduf_reader.has_file(os.path.join(pretrained_model_name_or_path, weights_name)):
+            return os.path.join(pretrained_model_name_or_path, weights_name)
+        elif subfolder is not None and os.path.isfile(
+            os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+        ):
+            return os.path.join(pretrained_model_name_or_path, weights_name)
+        else:
+            raise EnvironmentError(f"Error no file named {weights_name} found in archive {dduf_reader.files}.")
+    elif os.path.isfile(pretrained_model_name_or_path):
         return pretrained_model_name_or_path
     elif os.path.isdir(pretrained_model_name_or_path):
         if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):