mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-23 20:15:49 +08:00
Compare commits
8 Commits
cuda-128-g
...
devanshi00
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ff26d9ffd5 | ||
|
|
668f265054 | ||
|
|
55eaa6efb2 | ||
|
|
b603429ff5 | ||
|
|
3bc3fdb035 | ||
|
|
8cc38a75d3 | ||
|
|
e5bb10cfe1 | ||
|
|
ec541906c5 |
@@ -1,8 +1,8 @@
|
||||
FROM nvidia/cuda:12.8.0-runtime-ubuntu22.04
|
||||
FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ARG PYTHON_VERSION=3.11
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get -y update \
|
||||
@@ -32,12 +32,10 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
# Install torch, torchvision, and torchaudio together to ensure compatibility
|
||||
RUN uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
--index-url https://download.pytorch.org/whl/cu128
|
||||
torchaudio
|
||||
|
||||
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
FROM nvidia/cuda:12.8.0-runtime-ubuntu22.04
|
||||
FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ARG PYTHON_VERSION=3.11
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get -y update \
|
||||
@@ -32,12 +32,10 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
# Install torch, torchvision, and torchaudio together to ensure compatibility
|
||||
RUN uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio \
|
||||
--index-url https://download.pytorch.org/whl/cu128
|
||||
torchaudio
|
||||
|
||||
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
|
||||
|
||||
|
||||
@@ -675,6 +675,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
variant: Optional[str] = None,
|
||||
max_shard_size: Union[int, str] = "10GB",
|
||||
push_to_hub: bool = False,
|
||||
use_flashpack: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -707,6 +708,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
use_flashpack (`bool`, *optional*, defaults to `False`):
|
||||
Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format. FlashPack is a
|
||||
binary format that allows for faster loading. Requires the `flashpack` library to be installed.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
@@ -727,12 +731,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
" the logger on the traceback to understand the reason why the quantized model is not serializable."
|
||||
)
|
||||
|
||||
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
||||
weights_name = _add_variant(weights_name, variant)
|
||||
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
|
||||
".safetensors", "{suffix}.safetensors"
|
||||
)
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if push_to_hub:
|
||||
@@ -746,67 +744,80 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
# Only save the model itself if we are using distributed training
|
||||
model_to_save = self
|
||||
|
||||
# Attach architecture to the config
|
||||
# Save the config
|
||||
if is_main_process:
|
||||
model_to_save.save_config(save_directory)
|
||||
|
||||
# Save the model
|
||||
state_dict = model_to_save.state_dict()
|
||||
if use_flashpack:
|
||||
if not is_main_process:
|
||||
return
|
||||
|
||||
# Save the model
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
|
||||
)
|
||||
from ..utils.flashpack_utils import save_flashpack
|
||||
|
||||
# Clean the folder from a previous save
|
||||
if is_main_process:
|
||||
for filename in os.listdir(save_directory):
|
||||
if filename in state_dict_split.filename_to_tensors.keys():
|
||||
continue
|
||||
full_filename = os.path.join(save_directory, filename)
|
||||
if not os.path.isfile(full_filename):
|
||||
continue
|
||||
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
|
||||
weights_without_ext = weights_without_ext.replace("{suffix}", "")
|
||||
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
|
||||
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
|
||||
if (
|
||||
filename.startswith(weights_without_ext)
|
||||
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
|
||||
):
|
||||
os.remove(full_filename)
|
||||
|
||||
for filename, tensors in state_dict_split.filename_to_tensors.items():
|
||||
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
|
||||
filepath = os.path.join(save_directory, filename)
|
||||
if safe_serialization:
|
||||
# At some point we will need to deal better with save_function (used for TPU and other distributed
|
||||
# joyfulness), but for now this enough.
|
||||
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(shard, filepath)
|
||||
|
||||
if state_dict_split.is_sharded:
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
|
||||
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
|
||||
# Save the index as well
|
||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
logger.info(
|
||||
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
||||
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
save_flashpack(model_to_save, save_directory, variant=variant)
|
||||
else:
|
||||
path_to_weights = os.path.join(save_directory, weights_name)
|
||||
logger.info(f"Model weights saved in {path_to_weights}")
|
||||
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
||||
weights_name = _add_variant(weights_name, variant)
|
||||
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
|
||||
".safetensors", "{suffix}.safetensors"
|
||||
)
|
||||
|
||||
state_dict = model_to_save.state_dict()
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
|
||||
)
|
||||
|
||||
# Clean the folder from a previous save
|
||||
if is_main_process:
|
||||
for filename in os.listdir(save_directory):
|
||||
if filename in state_dict_split.filename_to_tensors.keys():
|
||||
continue
|
||||
full_filename = os.path.join(save_directory, filename)
|
||||
if not os.path.isfile(full_filename):
|
||||
continue
|
||||
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
|
||||
weights_without_ext = weights_without_ext.replace("{suffix}", "")
|
||||
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
|
||||
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
|
||||
if (
|
||||
filename.startswith(weights_without_ext)
|
||||
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
|
||||
):
|
||||
os.remove(full_filename)
|
||||
|
||||
# Save each shard
|
||||
for filename, tensors in state_dict_split.filename_to_tensors.items():
|
||||
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
|
||||
filepath = os.path.join(save_directory, filename)
|
||||
if safe_serialization:
|
||||
# At some point we will need to deal better with save_function (used for TPU and other distributed
|
||||
# joyfulness), but for now this enough.
|
||||
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(shard, filepath)
|
||||
|
||||
# Save index file if sharded
|
||||
if state_dict_split.is_sharded:
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
|
||||
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
|
||||
# Save the index as well
|
||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
logger.info(
|
||||
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
||||
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
else:
|
||||
path_to_weights = os.path.join(save_directory, weights_name)
|
||||
logger.info(f"Model weights saved in {path_to_weights}")
|
||||
|
||||
# Push to hub if requested (common to both paths)
|
||||
if push_to_hub:
|
||||
# Create a new empty model card and eventually tag it
|
||||
model_card = load_or_create_model_card(repo_id, token=token)
|
||||
@@ -939,6 +950,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
|
||||
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
|
||||
weights. If set to `False`, `safetensors` weights are not loaded.
|
||||
use_flashpack (`bool`, *optional*, defaults to `False`):
|
||||
If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack)
|
||||
weights if a compatible `.flashpack` file is found. If flashpack is unavailable or the `.flashpack`
|
||||
file cannot be used, automatic fallback to the standard loading path (for example, `safetensors`).
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
||||
@@ -982,6 +997,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
variant = kwargs.pop("variant", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
use_flashpack = kwargs.pop("use_flashpack", False)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
@@ -1199,7 +1215,31 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
|
||||
|
||||
model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
|
||||
else:
|
||||
|
||||
flashpack_file = None
|
||||
if use_flashpack:
|
||||
try:
|
||||
flashpack_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=_add_variant("model.flashpack", variant),
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
except EnvironmentError:
|
||||
flashpack_file = None
|
||||
logger.warning(
|
||||
"`use_flashpack` was specified to be True but not flashpack file was found. Resorting to non-flashpack alternatives."
|
||||
)
|
||||
|
||||
if flashpack_file is None:
|
||||
# in the case it is sharded, we have already the index
|
||||
if is_sharded:
|
||||
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
|
||||
@@ -1215,6 +1255,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
elif use_safetensors:
|
||||
logger.warning("Trying to load model weights with safetensors format.")
|
||||
try:
|
||||
resolved_model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
@@ -1280,6 +1321,29 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
if dtype_orig is not None:
|
||||
torch.set_default_dtype(dtype_orig)
|
||||
|
||||
if flashpack_file is not None:
|
||||
from ..utils.flashpack_utils import load_flashpack
|
||||
|
||||
# Even when using FlashPack, we preserve `low_cpu_mem_usage` behavior by initializing
|
||||
# the model with meta tensors. Since FlashPack cannot write into meta tensors, we
|
||||
# explicitly materialize parameters before loading to ensure correctness and parity
|
||||
# with the standard loading path.
|
||||
if any(p.device.type == "meta" for p in model.parameters()):
|
||||
model.to_empty(device="cpu")
|
||||
load_flashpack(model, flashpack_file)
|
||||
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
||||
model.eval()
|
||||
|
||||
if output_loading_info:
|
||||
return model, {
|
||||
"missing_keys": [],
|
||||
"unexpected_keys": [],
|
||||
"mismatched_keys": [],
|
||||
"error_msgs": [],
|
||||
}
|
||||
|
||||
return model
|
||||
|
||||
state_dict = None
|
||||
if not is_sharded:
|
||||
# Time to load the checkpoint
|
||||
@@ -1327,7 +1391,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
dduf_entries=dduf_entries,
|
||||
is_parallel_loading_enabled=is_parallel_loading_enabled,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
loading_info = {
|
||||
"missing_keys": missing_keys,
|
||||
@@ -1373,6 +1436,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
if output_loading_info:
|
||||
return model, loading_info
|
||||
|
||||
logger.warning(f"Model till end {pretrained_model_name_or_path} loaded successfully")
|
||||
|
||||
return model
|
||||
|
||||
# Adapted from `transformers`.
|
||||
|
||||
@@ -406,7 +406,6 @@ class LongCatImageTransformer2DModel(
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_repeated_blocks = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -756,6 +756,7 @@ def load_sub_model(
|
||||
low_cpu_mem_usage: bool,
|
||||
cached_folder: Union[str, os.PathLike],
|
||||
use_safetensors: bool,
|
||||
use_flashpack: bool,
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]],
|
||||
provider_options: Any,
|
||||
disable_mmap: bool,
|
||||
@@ -838,6 +839,9 @@ def load_sub_model(
|
||||
loading_kwargs["variant"] = model_variants.pop(name, None)
|
||||
loading_kwargs["use_safetensors"] = use_safetensors
|
||||
|
||||
if is_diffusers_model:
|
||||
loading_kwargs["use_flashpack"] = use_flashpack
|
||||
|
||||
if from_flax:
|
||||
loading_kwargs["from_flax"] = True
|
||||
|
||||
|
||||
@@ -243,6 +243,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
variant: Optional[str] = None,
|
||||
max_shard_size: Optional[Union[int, str]] = None,
|
||||
push_to_hub: bool = False,
|
||||
use_flashpack: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -268,7 +269,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
|
||||
use_flashpack (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip
|
||||
install flashpack`.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
@@ -340,6 +343,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
|
||||
save_method_accept_variant = "variant" in save_method_signature.parameters
|
||||
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
|
||||
save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters
|
||||
|
||||
save_kwargs = {}
|
||||
if save_method_accept_safe:
|
||||
@@ -349,6 +353,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
if save_method_accept_max_shard_size and max_shard_size is not None:
|
||||
# max_shard_size is expected to not be None in ModelMixin
|
||||
save_kwargs["max_shard_size"] = max_shard_size
|
||||
if save_method_accept_flashpack:
|
||||
save_kwargs["use_flashpack"] = use_flashpack
|
||||
|
||||
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
|
||||
|
||||
@@ -707,6 +713,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
|
||||
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
|
||||
weights. If set to `False`, safetensors weights are not loaded.
|
||||
use_flashpack (`bool`, *optional*, defaults to `False`):
|
||||
If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
|
||||
is found. If flashpack is unavailable or the `.flashpack` file cannot be used, automatic fallback to
|
||||
the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install
|
||||
flashpack`.
|
||||
use_onnx (`bool`, *optional*, defaults to `None`):
|
||||
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
|
||||
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
|
||||
@@ -772,6 +783,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
variant = kwargs.pop("variant", None)
|
||||
dduf_file = kwargs.pop("dduf_file", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
use_flashpack = kwargs.pop("use_flashpack", False)
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
@@ -1061,6 +1073,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cached_folder=cached_folder,
|
||||
use_safetensors=use_safetensors,
|
||||
use_flashpack=use_flashpack,
|
||||
dduf_entries=dduf_entries,
|
||||
provider_options=provider_options,
|
||||
disable_mmap=disable_mmap,
|
||||
|
||||
@@ -22,7 +22,6 @@ import flax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import logging
|
||||
from .scheduling_utils_flax import (
|
||||
CommonSchedulerState,
|
||||
FlaxKarrasDiffusionSchedulers,
|
||||
@@ -33,9 +32,6 @@ from .scheduling_utils_flax import (
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class DDIMSchedulerState:
|
||||
common: CommonSchedulerState
|
||||
@@ -129,10 +125,6 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
prediction_type: str = "epsilon",
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
logger.warning(
|
||||
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
|
||||
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
|
||||
)
|
||||
self.dtype = dtype
|
||||
|
||||
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
|
||||
@@ -160,10 +152,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
def scale_model_input(
|
||||
self,
|
||||
state: DDIMSchedulerState,
|
||||
sample: jnp.ndarray,
|
||||
timestep: Optional[int] = None,
|
||||
self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Args:
|
||||
@@ -201,9 +190,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
|
||||
alpha_prod_t = state.common.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = jnp.where(
|
||||
prev_timestep >= 0,
|
||||
state.common.alphas_cumprod[prev_timestep],
|
||||
state.final_alpha_cumprod,
|
||||
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
|
||||
)
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
@@ -99,7 +99,7 @@ def betas_for_alpha_bar(
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
||||
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
|
||||
def rescale_zero_terminal_snr(betas):
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
|
||||
|
||||
@@ -187,14 +187,14 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
clip_sample: bool = True,
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
|
||||
prediction_type: str = "epsilon",
|
||||
clip_sample_range: float = 1.0,
|
||||
timestep_spacing: Literal["leading", "trailing"] = "leading",
|
||||
timestep_spacing: str = "leading",
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -210,15 +210,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = (
|
||||
torch.linspace(
|
||||
beta_start**0.5,
|
||||
beta_end**0.5,
|
||||
num_train_timesteps,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
** 2
|
||||
)
|
||||
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
@@ -264,11 +256,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
) -> None:
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -320,10 +308,20 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`int`):
|
||||
timestep (`float`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
eta (`float`):
|
||||
The weight of noise for added noise in diffusion step.
|
||||
use_clipped_model_output (`bool`, defaults to `False`):
|
||||
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
|
||||
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
|
||||
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
|
||||
`use_clipped_model_output` has no effect.
|
||||
variance_noise (`torch.Tensor`):
|
||||
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
||||
itself. Useful for methods such as [`CycleDiffusion`].
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~schedulers.scheduling_ddim_inverse.DDIMInverseSchedulerOutput`] or
|
||||
`tuple`.
|
||||
@@ -337,8 +335,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
# 1. get previous step value (=t+1)
|
||||
prev_timestep = timestep
|
||||
timestep = min(
|
||||
timestep - self.config.num_train_timesteps // self.num_inference_steps,
|
||||
self.config.num_train_timesteps - 1,
|
||||
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
|
||||
)
|
||||
|
||||
# 2. compute alphas, betas
|
||||
@@ -381,5 +378,5 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
return (prev_sample, pred_original_sample)
|
||||
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
def __len__(self) -> int:
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
81
src/diffusers/utils/flashpack_utils.py
Normal file
81
src/diffusers/utils/flashpack_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from ..utils import _add_variant
|
||||
from .import_utils import is_flashpack_available
|
||||
from .logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def save_flashpack(
|
||||
model,
|
||||
save_directory: str,
|
||||
variant: Optional[str] = None,
|
||||
is_main_process: bool = True,
|
||||
):
|
||||
"""
|
||||
Save model weights in FlashPack format along with a metadata config.
|
||||
|
||||
Args:
|
||||
model: Diffusers model instance
|
||||
save_directory (`str`): Directory to save weights
|
||||
variant (`str`, *optional*): Model variant
|
||||
"""
|
||||
if not is_flashpack_available():
|
||||
raise ImportError(
|
||||
"The `use_flashpack=True` argument requires the `flashpack` package. "
|
||||
"Install it with `pip install flashpack`."
|
||||
)
|
||||
|
||||
from flashpack import pack_to_file
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
weights_name = _add_variant("model.flashpack", variant)
|
||||
weights_path = os.path.join(save_directory, weights_name)
|
||||
config_path = os.path.join(save_directory, "flashpack_config.json")
|
||||
|
||||
try:
|
||||
target_dtype = getattr(model, "dtype", None)
|
||||
logger.warning(f"Dtype used for FlashPack save: {target_dtype}")
|
||||
|
||||
# 1. Save binary weights
|
||||
pack_to_file(model, weights_path, target_dtype=target_dtype)
|
||||
|
||||
# 2. Save config metadata (best-effort)
|
||||
if hasattr(model, "config"):
|
||||
try:
|
||||
if hasattr(model.config, "to_dict"):
|
||||
config_data = model.config.to_dict()
|
||||
else:
|
||||
config_data = dict(model.config)
|
||||
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_data, f, indent=4)
|
||||
|
||||
except Exception as config_err:
|
||||
logger.warning(f"FlashPack weights saved, but config serialization failed: {config_err}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save weights in FlashPack format: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def load_flashpack(model, flashpack_file: str):
|
||||
"""
|
||||
Assign FlashPack weights from a file into an initialized PyTorch model.
|
||||
"""
|
||||
if not is_flashpack_available():
|
||||
raise ImportError("FlashPack weights require the `flashpack` package. Install with `pip install flashpack`.")
|
||||
|
||||
from flashpack import assign_from_file
|
||||
|
||||
logger.warning(f"Loading FlashPack weights from {flashpack_file}")
|
||||
|
||||
try:
|
||||
assign_from_file(model, flashpack_file)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load FlashPack weights from {flashpack_file}") from e
|
||||
@@ -231,6 +231,7 @@ _aiter_available, _aiter_version = _is_package_available("aiter")
|
||||
_kornia_available, _kornia_version = _is_package_available("kornia")
|
||||
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
|
||||
_av_available, _av_version = _is_package_available("av")
|
||||
_flashpack_available, _flashpack_version = _is_package_available("flashpack")
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
@@ -425,6 +426,10 @@ def is_av_available():
|
||||
return _av_available
|
||||
|
||||
|
||||
def is_flashpack_available():
|
||||
return _flashpack_available
|
||||
|
||||
|
||||
# docstyle-ignore
|
||||
FLAX_IMPORT_ERROR = """
|
||||
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
||||
@@ -942,6 +947,16 @@ def is_aiter_version(operation: str, version: str):
|
||||
return compare_versions(parse(_aiter_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_flashpack_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current flashpack version to a given reference with an operation.
|
||||
"""
|
||||
if not _flashpack_available:
|
||||
return False
|
||||
return compare_versions(parse(_flashpack_version), operation, version)
|
||||
|
||||
|
||||
def get_objects_from_module(module):
|
||||
"""
|
||||
Returns a dict of object names and values in a module, while skipping private/internal objects
|
||||
|
||||
Reference in New Issue
Block a user