Compare commits

..

10 Commits

Author SHA1 Message Date
yiyixuxu
1f57b175ae style 2026-01-24 03:50:01 +01:00
yiyixuxu
581a425130 tag loader_id from Automodel 2026-01-24 03:49:29 +01:00
David El Malih
e8e88ff2ce Improve docstrings and type hints in scheduling_ddpm_flax.py (#13024)
docs: improve docstring scheduling_ddpm_flax.py
2026-01-23 11:51:47 -08:00
David El Malih
6e24cd842c Improve docstrings and type hints in scheduling_ddim_parallel.py (#13023)
* docs: improve docstring scheduling_ddim_parallel.py

* docs: improve docstring scheduling_ddim_parallel.py

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* fix style

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-01-23 10:00:32 -08:00
Garry Ling
981eb802c6 feat: add qkv projection fuse for longcat transformers (#13021)
feat: add qkv fuse for longcat transformers

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-01-23 23:02:03 +05:30
jiqing-feng
1eb40c6dbd Resnet only use contiguous in training mode. (#12977)
* fix contiguous

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update tol

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* bigger tol

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update tol

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-01-23 18:40:10 +05:30
Sayak Paul
bff672f47f fix Dockerfiles for cuda and xformers. (#13022) 2026-01-23 16:45:14 +05:30
David El Malih
d4f97d1921 Improve docstrings and type hints in scheduling_ddim_inverse.py (#13020)
docs: improve docstring scheduling_ddim_inverse.py
2026-01-22 15:42:45 -08:00
David El Malih
1d32b19ad4 Improve docstrings and type hints in scheduling_ddim_flax.py (#13010)
* docs: improve docstring scheduling_ddim_flax.py

* docs: improve docstring scheduling_ddim_flax.py

* docs: improve docstring scheduling_ddim_flax.py
2026-01-22 09:11:14 -08:00
Garry Ling
699297f647 feat: accelerate longcat-image with regional compile (#13019) 2026-01-22 20:21:45 +05:30
19 changed files with 180 additions and 288 deletions

View File

@@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ARG PYTHON_VERSION=3.12
ARG PYTHON_VERSION=3.11
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get -y update \
@@ -32,10 +32,12 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
# Install torch, torchvision, and torchaudio together to ensure compatibility
RUN uv pip install --no-cache-dir \
torch \
torchvision \
torchaudio
torchaudio \
--index-url https://download.pytorch.org/whl/cu121
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"

View File

@@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ARG PYTHON_VERSION=3.12
ARG PYTHON_VERSION=3.11
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get -y update \
@@ -32,10 +32,12 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
# Install torch, torchvision, and torchaudio together to ensure compatibility
RUN uv pip install --no-cache-dir \
torch \
torchvision \
torchaudio
torchaudio \
--index-url https://download.pytorch.org/whl/cu121
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"

View File

@@ -18,7 +18,7 @@ from typing import Optional, Union
from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
from ..utils import logging
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, logging
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
@@ -220,4 +220,11 @@ class AutoModel(ConfigMixin):
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
kwargs = {**load_config_kwargs, **kwargs}
return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
model = model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
load_id_kwargs = {"pretrained_model_name_or_path": pretrained_model_or_path, **kwargs}
parts = [load_id_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
load_id = "|".join("null" if p is None else p for p in parts)
model._diffusers_load_id = load_id
return model

View File

@@ -675,7 +675,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
variant: Optional[str] = None,
max_shard_size: Union[int, str] = "10GB",
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
@@ -708,9 +707,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
use_flashpack (`bool`, *optional*, defaults to `False`):
Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format. FlashPack is a
binary format that allows for faster loading. Requires the `flashpack` library to be installed.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@@ -731,6 +727,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" the logger on the traceback to understand the reason why the quantized model is not serializable."
)
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
os.makedirs(save_directory, exist_ok=True)
if push_to_hub:
@@ -744,80 +746,67 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Only save the model itself if we are using distributed training
model_to_save = self
# Attach architecture to the config
# Save the config
if is_main_process:
model_to_save.save_config(save_directory)
if use_flashpack:
if not is_main_process:
return
# Save the model
state_dict = model_to_save.state_dict()
from ..utils.flashpack_utils import save_flashpack
# Save the model
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)
save_flashpack(model_to_save, save_directory, variant=variant)
else:
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)
state_dict = model_to_save.state_dict()
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)
# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)
# Save each shard
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)
# Save index file if sharded
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
torch.save(shard, filepath)
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
# Push to hub if requested (common to both paths)
if push_to_hub:
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token)
@@ -950,10 +939,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
weights. If set to `False`, `safetensors` weights are not loaded.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack)
weights if a compatible `.flashpack` file is found. If flashpack is unavailable or the `.flashpack`
file cannot be used, automatic fallback to the standard loading path (for example, `safetensors`).
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -997,7 +982,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_flashpack = kwargs.pop("use_flashpack", False)
quantization_config = kwargs.pop("quantization_config", None)
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)
@@ -1215,31 +1199,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
flashpack_file = None
if use_flashpack:
try:
flashpack_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant("model.flashpack", variant),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
except EnvironmentError:
flashpack_file = None
logger.warning(
"`use_flashpack` was specified to be True but not flashpack file was found. Resorting to non-flashpack alternatives."
)
if flashpack_file is None:
else:
# in the case it is sharded, we have already the index
if is_sharded:
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
@@ -1255,7 +1215,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
dduf_entries=dduf_entries,
)
elif use_safetensors:
logger.warning("Trying to load model weights with safetensors format.")
try:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
@@ -1321,29 +1280,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
if flashpack_file is not None:
from ..utils.flashpack_utils import load_flashpack
# Even when using FlashPack, we preserve `low_cpu_mem_usage` behavior by initializing
# the model with meta tensors. Since FlashPack cannot write into meta tensors, we
# explicitly materialize parameters before loading to ensure correctness and parity
# with the standard loading path.
if any(p.device.type == "meta" for p in model.parameters()):
model.to_empty(device="cpu")
load_flashpack(model, flashpack_file)
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
model.eval()
if output_loading_info:
return model, {
"missing_keys": [],
"unexpected_keys": [],
"mismatched_keys": [],
"error_msgs": [],
}
return model
state_dict = None
if not is_sharded:
# Time to load the checkpoint
@@ -1391,6 +1327,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
keep_in_fp32_modules=keep_in_fp32_modules,
dduf_entries=dduf_entries,
is_parallel_loading_enabled=is_parallel_loading_enabled,
disable_mmap=disable_mmap,
)
loading_info = {
"missing_keys": missing_keys,
@@ -1436,8 +1373,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if output_loading_info:
return model, loading_info
logger.warning(f"Model till end {pretrained_model_name_or_path} loaded successfully")
return model
# Adapted from `transformers`.

View File

@@ -366,7 +366,12 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor.contiguous())
# Only use contiguous() during training to avoid DDP gradient stride mismatch warning.
# In inference mode (eval or no_grad), skip contiguous() for better performance, especially on CPU.
# Issue: https://github.com/huggingface/diffusers/issues/12975
if self.training:
input_tensor = input_tensor.contiguous()
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

View File

@@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
@@ -400,12 +400,14 @@ class LongCatImageTransformer2DModel(
PeftAdapterMixin,
FromOriginalModelMixin,
CacheMixin,
AttentionMixin,
):
"""
The Transformer model introduced in Longcat-Image.
"""
_supports_gradient_checkpointing = True
_repeated_blocks = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"]
@register_to_config
def __init__(

View File

@@ -15,14 +15,14 @@
import inspect
import re
from collections import OrderedDict
from dataclasses import dataclass, field, fields
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
from ..configuration_utils import ConfigMixin, FrozenDict
from ..loaders.single_file_utils import _is_single_file_path_or_url
from ..utils import is_torch_available, logging
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
if is_torch_available():
@@ -185,7 +185,7 @@ class ComponentSpec:
"""
Return the names of all loadingrelated fields (i.e. those whose field.metadata["loading"] is True).
"""
return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
return DIFFUSERS_LOAD_ID_FIELDS.copy()
@property
def load_id(self) -> str:
@@ -197,7 +197,7 @@ class ComponentSpec:
return "null"
parts = [getattr(self, k) for k in self.loading_fields()]
parts = ["null" if p is None else p for p in parts]
return "|".join(p for p in parts if p)
return "|".join(parts)
@classmethod
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:

View File

@@ -756,7 +756,6 @@ def load_sub_model(
low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike],
use_safetensors: bool,
use_flashpack: bool,
dduf_entries: Optional[Dict[str, DDUFEntry]],
provider_options: Any,
disable_mmap: bool,
@@ -839,9 +838,6 @@ def load_sub_model(
loading_kwargs["variant"] = model_variants.pop(name, None)
loading_kwargs["use_safetensors"] = use_safetensors
if is_diffusers_model:
loading_kwargs["use_flashpack"] = use_flashpack
if from_flax:
loading_kwargs["from_flax"] = True

View File

@@ -243,7 +243,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant: Optional[str] = None,
max_shard_size: Optional[Union[int, str]] = None,
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
@@ -269,9 +268,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
use_flashpack (`bool`, *optional*, defaults to `False`):
Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip
install flashpack`.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@@ -343,7 +340,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
save_method_accept_variant = "variant" in save_method_signature.parameters
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters
save_kwargs = {}
if save_method_accept_safe:
@@ -353,8 +349,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if save_method_accept_max_shard_size and max_shard_size is not None:
# max_shard_size is expected to not be None in ModelMixin
save_kwargs["max_shard_size"] = max_shard_size
if save_method_accept_flashpack:
save_kwargs["use_flashpack"] = use_flashpack
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
@@ -713,11 +707,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
is found. If flashpack is unavailable or the `.flashpack` file cannot be used, automatic fallback to
the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install
flashpack`.
use_onnx (`bool`, *optional*, defaults to `None`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
@@ -783,7 +772,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant = kwargs.pop("variant", None)
dduf_file = kwargs.pop("dduf_file", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_flashpack = kwargs.pop("use_flashpack", False)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
quantization_config = kwargs.pop("quantization_config", None)
@@ -1073,7 +1061,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
low_cpu_mem_usage=low_cpu_mem_usage,
cached_folder=cached_folder,
use_safetensors=use_safetensors,
use_flashpack=use_flashpack,
dduf_entries=dduf_entries,
provider_options=provider_options,
disable_mmap=disable_mmap,

View File

@@ -22,6 +22,7 @@ import flax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
@@ -32,6 +33,9 @@ from .scheduling_utils_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class DDIMSchedulerState:
common: CommonSchedulerState
@@ -125,6 +129,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
@@ -152,7 +160,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
)
def scale_model_input(
self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
self,
state: DDIMSchedulerState,
sample: jnp.ndarray,
timestep: Optional[int] = None,
) -> jnp.ndarray:
"""
Args:
@@ -190,7 +201,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
alpha_prod_t = state.common.alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
prev_timestep >= 0,
state.common.alphas_cumprod[prev_timestep],
state.final_alpha_cumprod,
)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

View File

@@ -99,7 +99,7 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
@@ -187,14 +187,14 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
clip_sample_range: float = 1.0,
timestep_spacing: str = "leading",
timestep_spacing: Literal["leading", "trailing"] = "leading",
rescale_betas_zero_snr: bool = False,
**kwargs,
):
@@ -210,7 +210,15 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -256,7 +264,11 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int,
device: Optional[Union[str, torch.device]] = None,
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -308,20 +320,10 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
The weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
`use_clipped_model_output` has no effect.
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddim_inverse.DDIMInverseSchedulerOutput`] or
`tuple`.
@@ -335,7 +337,8 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
# 1. get previous step value (=t+1)
prev_timestep = timestep
timestep = min(
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
timestep - self.config.num_train_timesteps // self.num_inference_steps,
self.config.num_train_timesteps - 1,
)
# 2. compute alphas, betas
@@ -378,5 +381,5 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
return (prev_sample, pred_original_sample)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -101,7 +101,7 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
@@ -266,7 +266,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def _get_variance(self, timestep, prev_timestep=None):
def _get_variance(self, timestep: int, prev_timestep: Optional[int] = None) -> torch.Tensor:
if prev_timestep is None:
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
@@ -279,7 +279,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
return variance
def _batch_get_variance(self, t, prev_t):
def _batch_get_variance(self, t: torch.Tensor, prev_t: torch.Tensor) -> torch.Tensor:
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)
@@ -335,7 +335,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
return sample
# Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -392,7 +392,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
sample: torch.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
generator: Optional[torch.Generator] = None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[DDIMParallelSchedulerOutput, Tuple]:
@@ -406,11 +406,13 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
sample (`torch.Tensor`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
`self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
coincide with the one provided as input and `use_clipped_model_output` will have not effect.
generator: random number generator.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, compute "corrected" `model_output` from the clipped predicted original sample. This
correction is necessary because the predicted original sample is clipped to [-1, 1] when
`self.config.clip_sample` is `True`. If no clipping occurred, the "corrected" `model_output` matches
the input and `use_clipped_model_output` has no effect.
generator (`torch.Generator`, *optional*):
Random number generator.
variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we
can directly provide the noise for the variance itself. This is useful for methods such as
CycleDiffusion. (https://huggingface.co/papers/2210.05559)
@@ -496,7 +498,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
if variance_noise is None:
variance_noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
model_output.shape,
generator=generator,
device=model_output.device,
dtype=model_output.dtype,
)
variance = std_dev_t * variance_noise
@@ -513,7 +518,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
def batch_step_no_noise(
self,
model_output: torch.Tensor,
timesteps: List[int],
timesteps: torch.Tensor,
sample: torch.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
@@ -528,7 +533,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`): direct output from learned diffusion model.
timesteps (`List[int]`):
timesteps (`torch.Tensor`):
current discrete timesteps in the diffusion chain. This is now a list of integers.
sample (`torch.Tensor`):
current instance of sample being created by diffusion process.
@@ -696,5 +701,5 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -22,6 +22,7 @@ import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
@@ -32,6 +33,9 @@ from .scheduling_utils_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class DDPMSchedulerState:
common: CommonSchedulerState
@@ -42,7 +46,12 @@ class DDPMSchedulerState:
num_inference_steps: Optional[int] = None
@classmethod
def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray):
def create(
cls,
common: CommonSchedulerState,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
):
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps)
@@ -105,6 +114,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
@@ -123,7 +136,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
)
def scale_model_input(
self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
self,
state: DDPMSchedulerState,
sample: jnp.ndarray,
timestep: Optional[int] = None,
) -> jnp.ndarray:
"""
Args:

View File

@@ -23,6 +23,7 @@ from .constants import (
DEFAULT_HF_PARALLEL_LOADING_WORKERS,
DEPRECATED_REVISION_ARGS,
DIFFUSERS_DYNAMIC_MODULE_NAME,
DIFFUSERS_LOAD_ID_FIELDS,
FLAX_WEIGHTS_NAME,
GGUF_FILE_EXTENSION,
HF_ENABLE_PARALLEL_LOADING,

View File

@@ -73,3 +73,11 @@ DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoint
ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
DIFFUSERS_LOAD_ID_FIELDS = [
"pretrained_model_name_or_path",
"subfolder",
"variant",
"revision",
]

View File

@@ -1,81 +0,0 @@
import json
import os
from typing import Optional
from ..utils import _add_variant
from .import_utils import is_flashpack_available
from .logging import get_logger
logger = get_logger(__name__)
def save_flashpack(
model,
save_directory: str,
variant: Optional[str] = None,
is_main_process: bool = True,
):
"""
Save model weights in FlashPack format along with a metadata config.
Args:
model: Diffusers model instance
save_directory (`str`): Directory to save weights
variant (`str`, *optional*): Model variant
"""
if not is_flashpack_available():
raise ImportError(
"The `use_flashpack=True` argument requires the `flashpack` package. "
"Install it with `pip install flashpack`."
)
from flashpack import pack_to_file
os.makedirs(save_directory, exist_ok=True)
weights_name = _add_variant("model.flashpack", variant)
weights_path = os.path.join(save_directory, weights_name)
config_path = os.path.join(save_directory, "flashpack_config.json")
try:
target_dtype = getattr(model, "dtype", None)
logger.warning(f"Dtype used for FlashPack save: {target_dtype}")
# 1. Save binary weights
pack_to_file(model, weights_path, target_dtype=target_dtype)
# 2. Save config metadata (best-effort)
if hasattr(model, "config"):
try:
if hasattr(model.config, "to_dict"):
config_data = model.config.to_dict()
else:
config_data = dict(model.config)
with open(config_path, "w") as f:
json.dump(config_data, f, indent=4)
except Exception as config_err:
logger.warning(f"FlashPack weights saved, but config serialization failed: {config_err}")
except Exception as e:
logger.error(f"Failed to save weights in FlashPack format: {e}")
raise
def load_flashpack(model, flashpack_file: str):
"""
Assign FlashPack weights from a file into an initialized PyTorch model.
"""
if not is_flashpack_available():
raise ImportError("FlashPack weights require the `flashpack` package. Install with `pip install flashpack`.")
from flashpack import assign_from_file
logger.warning(f"Loading FlashPack weights from {flashpack_file}")
try:
assign_from_file(model, flashpack_file)
except Exception as e:
raise RuntimeError(f"Failed to load FlashPack weights from {flashpack_file}") from e

View File

@@ -231,7 +231,6 @@ _aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
_av_available, _av_version = _is_package_available("av")
_flashpack_available, _flashpack_version = _is_package_available("flashpack")
def is_torch_available():
@@ -426,10 +425,6 @@ def is_av_available():
return _av_available
def is_flashpack_available():
return _flashpack_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -947,16 +942,6 @@ def is_aiter_version(operation: str, version: str):
return compare_versions(parse(_aiter_version), operation, version)
@cache
def is_flashpack_version(operation: str, version: str):
"""
Compares the current flashpack version to a given reference with an operation.
"""
if not _flashpack_available:
return False
return compare_versions(parse(_flashpack_version), operation, version)
def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, while skipping private/internal objects

View File

@@ -248,6 +248,9 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=5e-1)
def test_save_load_dduf(self):
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
@is_flaky()
def test_model_cpu_offload_forward_pass(self):
super().test_inference_batch_single_identical(expected_max_diff=8e-4)

View File

@@ -191,6 +191,9 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=1e-2)
def test_save_load_dduf(self):
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
@slow
@require_torch_accelerator