Compare commits

...

10 Commits

Author SHA1 Message Date
Marc Sun
1cd5155bb8 remove print 2024-12-04 13:04:48 +00:00
Marc Sun
b14bffeffe first draft 2024-12-04 13:03:35 +00:00
Marc Sun
e66c4d0dab Update src/diffusers/pipelines/pipeline_utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-12-04 13:57:21 +01:00
Marc Sun
7d2c7d5553 Update src/diffusers/pipelines/pipeline_utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2024-12-04 13:56:21 +01:00
Sayak Paul
78135f1478 Merge branch 'main' into dduf 2024-12-04 17:34:58 +05:30
sayakpaul
d8408677c5 updates 2024-12-03 14:06:47 +00:00
Sayak Paul
cbee7cbc6b Merge branch 'main' into dduf 2024-11-30 08:56:15 +05:30
Marc Sun
2eeda25321 switch to zip uncompressed 2024-11-28 16:06:04 +01:00
Marc Sun
0389333113 style 2024-11-27 18:01:43 +01:00
Marc Sun
1fb86e34c0 load and save dduf archive 2024-11-27 18:01:36 +01:00
9 changed files with 204 additions and 38 deletions

View File src/diffusers/configuration_utils.py

@@ -347,6 +347,7 @@ class ConfigMixin:
_ = kwargs.pop("mirror", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
dduf_reader = kwargs.pop("dduf_reader", None)
user_agent = {**user_agent, "file_type": "config"}
user_agent = http_user_agent(user_agent)
@@ -358,8 +359,22 @@ class ConfigMixin:
"`self.config_name` is not defined. Note that one should not load a config from "
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
)
-if os.path.isfile(pretrained_model_name_or_path):
+# Custom path for now
+if dduf_reader:
+    if subfolder is not None:
+        if dduf_reader.has_file(os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)):
+            config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
+        else:
+            raise ValueError(
+                f"We did not manage to find the file {os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)} in the archive. We only have the following files {dduf_reader.files}"
+            )
+    elif dduf_reader.has_file(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+        config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+    else:
+        raise ValueError(
+            f"We did not manage to find the file {os.path.join(pretrained_model_name_or_path, cls.config_name)} in the archive. We only have the following files {dduf_reader.files}"
+        )
+elif os.path.isfile(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if subfolder is not None and os.path.isfile(
@@ -426,10 +441,8 @@ class ConfigMixin:
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a {cls.config_name} file"
)
try:
# Load config dict
-config_dict = cls._dict_from_json_file(config_file)
+config_dict = cls._dict_from_json_file(config_file, dduf_reader=dduf_reader)
commit_hash = extract_commit_hash(config_file)
except (json.JSONDecodeError, UnicodeDecodeError):
@@ -552,9 +565,12 @@ class ConfigMixin:
return init_dict, unused_kwargs, hidden_config_dict
@classmethod
-def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
-    with open(json_file, "r", encoding="utf-8") as reader:
-        text = reader.read()
+def _dict_from_json_file(cls, json_file: Union[str, os.PathLike], dduf_reader=None):
+    if dduf_reader:
+        text = dduf_reader.read_file(json_file, encoding="utf-8")
+    else:
+        with open(json_file, "r", encoding="utf-8") as reader:
+            text = reader.read()
return json.loads(text)
def __repr__(self):

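Taken together, the ConfigMixin changes above let a config be resolved either from disk or from inside a DDUF archive. A minimal sketch of the combined lookup (a hypothetical standalone helper, not the PR's API; it assumes a `DDUFReader` instance as defined later in this PR):

```python
import json
import os


def load_config_dict(root, config_name, subfolder=None, dduf_reader=None):
    """Parse a JSON config from a DDUF archive if a reader is given, else from disk."""
    path = os.path.join(root, subfolder, config_name) if subfolder else os.path.join(root, config_name)
    if dduf_reader is not None:
        # Look the file up inside the archive instead of on the filesystem.
        if not dduf_reader.has_file(path):
            raise ValueError(f"{path} is not in the archive; available files: {dduf_reader.files}")
        text = dduf_reader.read_file(path, encoding="utf-8")
    else:
        with open(path, "r", encoding="utf-8") as f:
            text = f.read()
    return json.loads(text)
```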
View File src/diffusers/models/model_loading_utils.py

@@ -128,7 +128,7 @@ def _fetch_remapped_cls_from_config(config, old_class):
return old_class
-def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
+def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, dduf_reader=None):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
@@ -138,8 +138,15 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
return checkpoint_file
try:
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
+if dduf_reader:
+    checkpoint_file = dduf_reader.read_file(checkpoint_file)
if file_extension == SAFETENSORS_FILE_EXTENSION:
-    return safetensors.torch.load_file(checkpoint_file, device="cpu")
+    if dduf_reader:
+        # tensors are loaded on cpu
+        return safetensors.torch.load(checkpoint_file)
+    else:
+        return safetensors.torch.load_file(checkpoint_file, device="cpu")
else:
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
return torch.load(
@@ -272,6 +279,7 @@ def _fetch_index_file(
revision,
user_agent,
commit_hash,
dduf_reader=None,
):
if is_local:
index_file = Path(
@@ -297,6 +305,7 @@ def _fetch_index_file(
subfolder=None,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_reader=dduf_reader,
)
index_file = Path(index_file)
except (EntryNotFoundError, EnvironmentError):

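The safetensors branch above hinges on an API distinction: `safetensors.torch.load_file` takes a *path* (plus a target device), while `safetensors.torch.load` takes *bytes*, which is exactly what `dduf_reader.read_file` returns for an archive member; in both cases the tensors end up on CPU here. A rough sketch (the archive member name is illustrative):

```python
import safetensors.torch

# On-disk checkpoint: load_file() takes a filesystem path.
state_dict = safetensors.torch.load_file("diffusion_pytorch_model.safetensors", device="cpu")

# DDUF checkpoint: read_file() returns raw bytes, so deserialize in memory
# with load(); tensors are created on CPU.
raw_bytes = dduf_reader.read_file("unet/diffusion_pytorch_model.safetensors")
state_dict = safetensors.torch.load(raw_bytes)
```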
View File src/diffusers/models/modeling_utils.py

@@ -557,6 +557,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
quantization_config = kwargs.pop("quantization_config", None)
dduf_reader = kwargs.pop("dduf_reader", None)
allow_pickle = False
if use_safetensors is None:
@@ -649,6 +650,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
dduf_reader=dduf_reader,
**kwargs,
)
# no in-place modification of the original config.
@@ -724,6 +726,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
"revision": revision,
"user_agent": user_agent,
"commit_hash": commit_hash,
"dduf_reader": dduf_reader,
}
index_file = _fetch_index_file(**index_file_kwargs)
# In case the index file was not found we still have to consider the legacy format.
@@ -759,7 +762,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
model = load_flax_checkpoint_in_pytorch_model(model, model_file)
else:
-if is_sharded:
+# in the case it is sharded, we already have the index
+if is_sharded and not dduf_reader:
sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_file,
@@ -790,6 +794,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_reader=dduf_reader,
)
except IOError as e:
@@ -813,6 +818,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_reader=dduf_reader,
)
if low_cpu_mem_usage:
@@ -837,7 +843,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
elif is_quant_method_bnb:
param_device = torch.cuda.current_device()
-state_dict = load_state_dict(model_file, variant=variant)
+state_dict = load_state_dict(model_file, variant=variant, dduf_reader=dduf_reader)
model._convert_deprecated_attention_blocks(state_dict)
# move the params from meta device to cpu
@@ -937,7 +943,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
model = cls.from_config(config, **unused_kwargs)
-state_dict = load_state_dict(model_file, variant=variant)
+state_dict = load_state_dict(model_file, variant=variant, dduf_reader=dduf_reader)
model._convert_deprecated_attention_blocks(state_dict)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(

View File src/diffusers/pipelines/pipeline_loading_utils.py

@@ -627,6 +627,7 @@ def load_sub_model(
low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike],
use_safetensors: bool,
dduf_reader,
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
@@ -721,7 +722,10 @@ def load_sub_model(
loading_kwargs["low_cpu_mem_usage"] = False
# check if the module is in a subdirectory
-if os.path.isdir(os.path.join(cached_folder, name)):
+if dduf_reader:
+    loading_kwargs["dduf_reader"] = dduf_reader
+    loaded_sub_model = load_method(name, **loading_kwargs)
+elif os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
else:
# else load from the root directory

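The behavioral shift here: with a `dduf_reader`, `load_method` receives the bare component name (for example `"unet"`), so downstream path joins resolve against the archive's internal layout rather than a cached folder on disk. Roughly (component name illustrative; `load_method`, `cached_folder`, and `loading_kwargs` as assembled above):

```python
import os

if dduf_reader:
    # Paths like "unet/config.json" are looked up inside the archive.
    loading_kwargs["dduf_reader"] = dduf_reader
    loaded_sub_model = load_method("unet", **loading_kwargs)
else:
    # Classic path: the component lives in a subdirectory of the cached folder.
    loaded_sub_model = load_method(os.path.join(cached_folder, "unet"), **loading_kwargs)
```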
View File src/diffusers/pipelines/pipeline_utils.py

@@ -50,6 +50,7 @@ from ..utils import (
CONFIG_NAME,
DEPRECATED_REVISION_ARGS,
BaseOutput,
DDUFReader,
PushToHubMixin,
is_accelerate_available,
is_accelerate_version,
@@ -193,6 +194,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant: Optional[str] = None,
max_shard_size: Optional[Union[int, str]] = None,
push_to_hub: bool = False,
dduf_filename: Optional[Union[str, os.PathLike]] = None,
**kwargs,
):
"""
@@ -301,9 +303,56 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
if dduf_filename:
    import shutil
    import zipfile

    def zipdir(dir_to_archive, zipf):
        """Archive a directory"""
        for root, dirs, files in os.walk(dir_to_archive):
            for file in files:
                file_path = os.path.join(root, file)
                arcname = os.path.join(
                    os.path.basename(dir_to_archive), os.path.relpath(file_path, start=dir_to_archive)
                )
                zipf.write(file_path, arcname=arcname)

    dduf_file_path = os.path.join(save_directory, dduf_filename)
    if os.path.isdir(dduf_file_path):
        logger.warning(
            f"Removing the existing folder {dduf_file_path} so that we can save the DDUF archive."
        )
        shutil.rmtree(dduf_file_path)

    if (
        os.path.exists(dduf_file_path)
        and os.path.isfile(dduf_file_path)
        and zipfile.is_zipfile(dduf_file_path)
    ):
        # Open in append mode if the file exists
        mode = "a"
    else:
        # Open in write mode to create it if it doesn't exist
        mode = "w"

    with zipfile.ZipFile(dduf_file_path, mode=mode, compression=zipfile.ZIP_STORED) as zipf:
        dir_to_archive = os.path.join(save_directory, pipeline_component_name)
        if os.path.isdir(dir_to_archive):
            zipdir(dir_to_archive, zipf)
            shutil.rmtree(dir_to_archive)

# finally save the config
self.save_config(save_directory)

# Takes care of including the "model_index.json" inside the ZIP.
# TODO: Include a DDUF metadata file.
if dduf_filename:
    import zipfile

    with zipfile.ZipFile(dduf_file_path, mode="a", compression=zipfile.ZIP_STORED) as zipf:
        config_path = os.path.join(save_directory, self.config_name)
        zipf.write(config_path, arcname=os.path.basename(config_path))
        os.remove(config_path)
if push_to_hub:
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
@@ -523,6 +572,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
- A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights
saved using
[`~DiffusionPipeline.save_pretrained`].
- A path to a *directory* (for example `./my_pipeline_directory/`) containing a DDUF archive or
  folder.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the
dtype is automatically derived from the model's weights.
@@ -617,6 +668,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant (`str`, *optional*):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
dduf (`str`, *optional*):
Load weights from the specified DDUF archive or folder.
<Tip>
@@ -666,6 +719,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
offload_state_dict = kwargs.pop("offload_state_dict", False)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
dduf = kwargs.pop("dduf", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
@@ -736,6 +790,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
custom_pipeline=custom_pipeline,
custom_revision=custom_revision,
variant=variant,
dduf=dduf,
load_connected_pipeline=load_connected_pipeline,
**kwargs,
)
@@ -757,7 +812,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
)
logger.warning(warn_msg)
-config_dict = cls.load_config(cached_folder)
+dduf_reader = None
+if dduf:
+    zip_file_path = os.path.join(cached_folder, dduf)
+    dduf_reader = DDUFReader(zip_file_path)
+    # The reader already contains all the files needed, no need to check again
+    cached_folder = ""
+config_dict = cls.load_config(cached_folder, dduf_reader=dduf_reader)
# pop out "_ignore_files" as it is only needed for download
config_dict.pop("_ignore_files", None)
@@ -914,6 +976,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
low_cpu_mem_usage=low_cpu_mem_usage,
cached_folder=cached_folder,
use_safetensors=use_safetensors,
dduf_reader=dduf_reader,
)
logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -1227,6 +1290,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant (`str`, *optional*):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
dduf (`str`, *optional*):
Load weights from the specified DDUF archive or folder.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
@@ -1267,6 +1332,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
trust_remote_code = kwargs.pop("trust_remote_code", False)
dduf = kwargs.pop("dduf", None)
allow_pickle = False
if use_safetensors is None:
@@ -1285,7 +1351,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
local_files_only = True
model_info_call_error = e # save error to reraise it if model is not cached locally
-if not local_files_only:
+if dduf is not None and not local_files_only:
+    dduf_available = any(dduf in sibling.rfilename for sibling in info.siblings)
+    if not dduf_available:
+        raise ValueError(f"Requested {dduf} file is not available in {pretrained_model_name}.")
+if not local_files_only and not dduf:
filenames = {sibling.rfilename for sibling in info.siblings}
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
warn_msg = (
@@ -1346,7 +1419,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else []
# also allow downloading config.json files with the model
allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names]
allow_patterns += [
SCHEDULER_CONFIG_NAME,
CONFIG_NAME,
@@ -1425,10 +1497,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
return snapshot_folder
user_agent = {"pipeline_class": cls.__name__}
-if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
+if not dduf and custom_pipeline is not None and not custom_pipeline.endswith(".py"):
user_agent["custom_pipeline"] = custom_pipeline
# download all allow_patterns - ignore_patterns
# also allow downloading the dduf
if dduf is not None:
allow_patterns = [dduf]
ignore_patterns = []
try:
cached_folder = snapshot_download(
pretrained_model_name,
@@ -1443,26 +1519,27 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
)
# retrieve pipeline class from local file
-cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
-cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
-diffusers_module = importlib.import_module(__name__.split(".")[0])
-pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None
-if pipeline_class is not None and pipeline_class._load_connected_pipes:
-    modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
-    connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], [])
-    for connected_pipe_repo_id in connected_pipes:
-        download_kwargs = {
-            "cache_dir": cache_dir,
-            "force_download": force_download,
-            "proxies": proxies,
-            "local_files_only": local_files_only,
-            "token": token,
-            "variant": variant,
-            "use_safetensors": use_safetensors,
-        }
-        DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs)
+if not dduf:
+    cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None)
+    cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name
+    diffusers_module = importlib.import_module(__name__.split(".")[0])
+    pipeline_class = getattr(diffusers_module, cls_name, None) if isinstance(cls_name, str) else None
+    if pipeline_class is not None and pipeline_class._load_connected_pipes:
+        modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
+        connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], [])
+        for connected_pipe_repo_id in connected_pipes:
+            download_kwargs = {
+                "cache_dir": cache_dir,
+                "force_download": force_download,
+                "proxies": proxies,
+                "local_files_only": local_files_only,
+                "token": token,
+                "variant": variant,
+                "use_safetensors": use_safetensors,
+            }
+            DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs)
return cached_folder

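Assuming the API lands as shown in this file, end-to-end usage would look roughly like this (model id and file names are illustrative):

```python
from diffusers import DiffusionPipeline

# Save: each component folder is zipped uncompressed (ZIP_STORED) into the
# archive, together with model_index.json.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2")
pipe.save_pretrained("./sd2", dduf_filename="sd2.dduf")

# Load: `dduf` names the archive inside the local folder (or Hub repo, in
# which case only the archive is downloaded).
pipe = DiffusionPipeline.from_pretrained("./sd2", dduf="sd2.dduf")
```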
View File src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py

@@ -34,6 +34,7 @@ from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline
from .modeling_stable_audio import StableAudioProjectionModel
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
@@ -732,7 +733,7 @@ class StableAudioPipeline(DiffusionPipeline):
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if XLA_AVAILABLE:
xm.mark_step()

View File src/diffusers/utils/__init__.py

@@ -35,6 +35,7 @@ from .constants import (
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
from .dduf import DDUFReader
from .deprecation_utils import deprecate
from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module

View File src/diffusers/utils/dduf.py

@@ -0,0 +1,41 @@
import zipfile


class DDUFReader:
    def __init__(self, dduf_file):
        self.dduf_file = dduf_file
        self.files = []
        self.post_init()

    def post_init(self):
        """
        Check that the DDUF file is a valid ZIP archive and cache its file list.
        """
        if not zipfile.is_zipfile(self.dduf_file):
            raise ValueError(f"The file '{self.dduf_file}' is not a valid ZIP archive.")

        try:
            with zipfile.ZipFile(self.dduf_file, "r") as zf:
                # Check integrity and store file list
                zf.testzip()  # Returns None if no corrupt files are found
                self.files = zf.namelist()
        except zipfile.BadZipFile:
            raise ValueError(f"The file '{self.dduf_file}' is not a valid ZIP archive.")
        except Exception as e:
            raise RuntimeError(f"An error occurred while validating the ZIP file: {e}")

    def has_file(self, file):
        return file in self.files

    def read_file(self, file_name, encoding=None):
        """
        Reads the content of a specific file in the ZIP archive without extracting it.
        Returns bytes, or a decoded string if `encoding` is given.
        """
        if file_name not in self.files:
            raise ValueError(f"{file_name} is not in the list of files {self.files}")
        with zipfile.ZipFile(self.dduf_file, "r") as zf:
            with zf.open(file_name) as f:
                data = f.read()
        if encoding is not None:
            data = data.decode(encoding)
        return data

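A quick usage sketch for the reader class above (archive and member names are illustrative):

```python
from diffusers.utils import DDUFReader

reader = DDUFReader("sd2.dduf")
print(reader.files)  # every member of the validated ZIP archive

if reader.has_file("model_index.json"):
    # Text members can be decoded on the fly; binary members come back as bytes.
    index_json = reader.read_file("model_index.json", encoding="utf-8")
```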
View File src/diffusers/utils/hub_utils.py

@@ -291,9 +291,20 @@ def _get_model_file(
user_agent: Optional[Union[Dict, str]] = None,
revision: Optional[str] = None,
commit_hash: Optional[str] = None,
dduf_reader=None,
):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-if os.path.isfile(pretrained_model_name_or_path):
+if dduf_reader:
+    if dduf_reader.has_file(os.path.join(pretrained_model_name_or_path, weights_name)):
+        return os.path.join(pretrained_model_name_or_path, weights_name)
+    elif subfolder is not None and dduf_reader.has_file(
+        os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+    ):
+        return os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
+    else:
+        raise EnvironmentError(f"Error no file named {weights_name} found in archive {dduf_reader.files}.")
+elif os.path.isfile(pretrained_model_name_or_path):
return pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):