Compare commits

..

5 Commits

Author           SHA1        Message                                                                    Date
DN6              cfff46069d  add custom mesh support                                                    2026-02-02 13:12:09 +05:30
Sayak Paul       bff672f47f  fix Dockerfiles for cuda and xformers. (#13022)                            2026-01-23 16:45:14 +05:30
David El Malih   d4f97d1921  Improve docstrings and type hints in scheduling_ddim_inverse.py (#13020)  2026-01-22 15:42:45 -08:00
                             docs: improve docstring scheduling_ddim_inverse.py
David El Malih   1d32b19ad4  Improve docstrings and type hints in scheduling_ddim_flax.py (#13010)     2026-01-22 09:11:14 -08:00
                             * docs: improve docstring scheduling_ddim_flax.py
                             * docs: improve docstring scheduling_ddim_flax.py
                             * docs: improve docstring scheduling_ddim_flax.py
Garry Ling       699297f647  feat: accelerate longcat-image with regional compile (#13019)             2026-01-22 20:21:45 +05:30
11 changed files with 116 additions and 266 deletions

View File

@@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ARG PYTHON_VERSION=3.12
ARG PYTHON_VERSION=3.11
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get -y update \
@@ -32,10 +32,12 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
# Install torch, torchvision, and torchaudio together to ensure compatibility
RUN uv pip install --no-cache-dir \
torch \
torchvision \
torchaudio
torchaudio \
--index-url https://download.pytorch.org/whl/cu121
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
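
A quick way to confirm that the cu121 wheels from the added --index-url actually made it into the image is a short check inside the container (a minimal sketch, not part of the Dockerfile itself):

    import torch

    # The cu121 index should yield a CUDA 12.1 build of torch.
    print(torch.__version__)          # e.g. "2.x.x+cu121"
    print(torch.version.cuda)         # expected: "12.1"
    print(torch.cuda.is_available())  # True only if a GPU is visible to the container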

View File

@@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ARG PYTHON_VERSION=3.12
ARG PYTHON_VERSION=3.11
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get -y update \
@@ -32,10 +32,12 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
# Install torch, torchvision, and torchaudio together to ensure compatibility
RUN uv pip install --no-cache-dir \
torch \
torchvision \
torchaudio
torchaudio \
--index-url https://download.pytorch.org/whl/cu121
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"

View File

@@ -59,6 +59,12 @@ class ContextParallelConfig:
rotate_method (`str`, *optional*, defaults to `"allgather"`):
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
is supported.
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
creating a new one. This is useful when combining context parallelism with other parallelism strategies
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
"""
@@ -67,6 +73,7 @@ class ContextParallelConfig:
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None
_rank: int = None
_world_size: int = None
@@ -115,7 +122,7 @@ class ContextParallelConfig:
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
self._flattened_mesh = self._mesh._flatten()
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()
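
For illustration, a custom mesh matching the docstring's example could be built and handed to the config roughly as follows. This is a sketch that assumes eight processes launched with torchrun and that `ContextParallelConfig` is importable from the top-level `diffusers` namespace in your release:

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from diffusers import ContextParallelConfig  # import path assumed; adjust to your version

    # 3-D mesh shared between ring attention (size 2) and FSDP (size 4); the unused
    # "ulysses" dimension is kept at size 1, as the docstring above suggests.
    mesh = init_device_mesh(
        "cuda",
        mesh_shape=(2, 1, 4),
        mesh_dim_names=("ring", "ulysses", "fsdp"),
    )
    cp_config = ContextParallelConfig(ring_degree=2, mesh=mesh)
    # cp_config is then passed to the model's parallelism entry point as usual.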

View File

@@ -675,7 +675,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
variant: Optional[str] = None,
max_shard_size: Union[int, str] = "10GB",
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
@@ -708,9 +707,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
use_flashpack (`bool`, *optional*, defaults to `False`):
Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format. FlashPack is a
binary format that allows for faster loading. Requires the `flashpack` library to be installed.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@@ -731,6 +727,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" the logger on the traceback to understand the reason why the quantized model is not serializable."
)
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
os.makedirs(save_directory, exist_ok=True)
if push_to_hub:
@@ -744,80 +746,67 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Only save the model itself if we are using distributed training
model_to_save = self
# Attach architecture to the config
# Save the config
if is_main_process:
model_to_save.save_config(save_directory)
if use_flashpack:
if not is_main_process:
return
# Save the model
state_dict = model_to_save.state_dict()
from ..utils.flashpack_utils import save_flashpack
# Save the model
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)
save_flashpack(model_to_save, save_directory, variant=variant)
else:
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)
state_dict = model_to_save.state_dict()
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)
# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)
# Save each shard
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)
# Save index file if sharded
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
torch.save(shard, filepath)
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
# Push to hub if requested (common to both paths)
if push_to_hub:
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token)
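
With the FlashPack branch gone, the remaining save path shards the state dict via huggingface_hub and writes each shard with safetensors. Condensed into a standalone sketch (plain nn.Module, hypothetical output directory, index filename hard-coded for brevity):

    import json
    import os

    import safetensors.torch
    import torch
    from huggingface_hub import split_torch_state_dict_into_shards

    def save_sharded(model: torch.nn.Module, save_directory: str, max_shard_size="10GB"):
        os.makedirs(save_directory, exist_ok=True)
        state_dict = model.state_dict()
        # Same helper the method above relies on to pick shard boundaries and file names.
        split = split_torch_state_dict_into_shards(
            state_dict,
            max_shard_size=max_shard_size,
            filename_pattern="model{suffix}.safetensors",
        )
        for filename, tensors in split.filename_to_tensors.items():
            shard = {name: state_dict[name].contiguous() for name in tensors}
            safetensors.torch.save_file(
                shard, os.path.join(save_directory, filename), metadata={"format": "pt"}
            )
        if split.is_sharded:
            # Index mapping every tensor to its shard, mirroring the code above.
            index = {"metadata": split.metadata, "weight_map": split.tensor_to_filename}
            with open(os.path.join(save_directory, "model.safetensors.index.json"), "w", encoding="utf-8") as f:
                f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")
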
@@ -950,10 +939,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
weights. If set to `False`, `safetensors` weights are not loaded.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack)
weights if a compatible `.flashpack` file is found. If flashpack is unavailable or the `.flashpack`
file cannot be used, automatic fallback to the standard loading path (for example, `safetensors`).
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -997,7 +982,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_flashpack = kwargs.pop("use_flashpack", False)
quantization_config = kwargs.pop("quantization_config", None)
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)
@@ -1215,31 +1199,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
flashpack_file = None
if use_flashpack:
try:
flashpack_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant("model.flashpack", variant),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
except EnvironmentError:
flashpack_file = None
logger.warning(
"`use_flashpack` was specified to be True but not flashpack file was found. Resorting to non-flashpack alternatives."
)
if flashpack_file is None:
else:
# in the case it is sharded, we have already the index
if is_sharded:
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
@@ -1255,7 +1215,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
dduf_entries=dduf_entries,
)
elif use_safetensors:
logger.warning("Trying to load model weights with safetensors format.")
try:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
@@ -1321,29 +1280,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
if flashpack_file is not None:
from ..utils.flashpack_utils import load_flashpack
# Even when using FlashPack, we preserve `low_cpu_mem_usage` behavior by initializing
# the model with meta tensors. Since FlashPack cannot write into meta tensors, we
# explicitly materialize parameters before loading to ensure correctness and parity
# with the standard loading path.
if any(p.device.type == "meta" for p in model.parameters()):
model.to_empty(device="cpu")
load_flashpack(model, flashpack_file)
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
model.eval()
if output_loading_info:
return model, {
"missing_keys": [],
"unexpected_keys": [],
"mismatched_keys": [],
"error_msgs": [],
}
return model
state_dict = None
if not is_sharded:
# Time to load the checkpoint
@@ -1391,6 +1327,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
keep_in_fp32_modules=keep_in_fp32_modules,
dduf_entries=dduf_entries,
is_parallel_loading_enabled=is_parallel_loading_enabled,
disable_mmap=disable_mmap,
)
loading_info = {
"missing_keys": missing_keys,
@@ -1436,8 +1373,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if output_loading_info:
return model, loading_info
logger.warning(f"Model till end {pretrained_model_name_or_path} loaded successfully")
return model
# Adapted from `transformers`.
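
The `output_loading_info` and `disable_mmap` code paths touched above are exercised from user code roughly like this (the repo id is a placeholder, not a real checkpoint):

    from diffusers import UNet2DConditionModel

    model, loading_info = UNet2DConditionModel.from_pretrained(
        "some-org/some-model",     # hypothetical repo id
        subfolder="unet",
        output_loading_info=True,  # also returns missing/unexpected/mismatched keys and error messages
        disable_mmap=True,         # skip mmap, e.g. for checkpoints on network mounts
    )
    print(loading_info["missing_keys"])
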
@@ -1634,7 +1569,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
mesh = None
if config.context_parallel_config is not None:
cp_config = config.context_parallel_config
mesh = torch.distributed.device_mesh.init_device_mesh(
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=cp_config.mesh_shape,
mesh_dim_names=cp_config.mesh_dim_names,

View File

@@ -406,6 +406,7 @@ class LongCatImageTransformer2DModel(
"""
_supports_gradient_checkpointing = True
_repeated_blocks = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"]
@register_to_config
def __init__(
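
The new `_repeated_blocks` entry is what regional compilation keys on. Assuming the `compile_repeated_blocks` helper on `ModelMixin` and a hypothetical checkpoint id, enabling it would look roughly like:

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "some-org/longcat-image",  # placeholder id
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Only the block classes listed in `_repeated_blocks` are compiled, which keeps
    # compile time far below a full torch.compile of the whole transformer.
    pipe.transformer.compile_repeated_blocks(fullgraph=True)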

View File

@@ -756,7 +756,6 @@ def load_sub_model(
low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike],
use_safetensors: bool,
use_flashpack: bool,
dduf_entries: Optional[Dict[str, DDUFEntry]],
provider_options: Any,
disable_mmap: bool,
@@ -839,9 +838,6 @@ def load_sub_model(
loading_kwargs["variant"] = model_variants.pop(name, None)
loading_kwargs["use_safetensors"] = use_safetensors
if is_diffusers_model:
loading_kwargs["use_flashpack"] = use_flashpack
if from_flax:
loading_kwargs["from_flax"] = True

View File

@@ -243,7 +243,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant: Optional[str] = None,
max_shard_size: Optional[Union[int, str]] = None,
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
@@ -269,9 +268,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
use_flashpack (`bool`, *optional*, defaults to `False`):
Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip
install flashpack`.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@@ -343,7 +340,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
save_method_accept_variant = "variant" in save_method_signature.parameters
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters
save_kwargs = {}
if save_method_accept_safe:
@@ -353,8 +349,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if save_method_accept_max_shard_size and max_shard_size is not None:
# max_shard_size is expected to not be None in ModelMixin
save_kwargs["max_shard_size"] = max_shard_size
if save_method_accept_flashpack:
save_kwargs["use_flashpack"] = use_flashpack
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
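
The pattern above, probing each component's save method for the kwargs it supports before calling it, can be condensed into a generic helper (a sketch, not the exact diffusers code):

    import inspect

    def call_with_supported_kwargs(save_method, path, **candidate_kwargs):
        # Forward only the kwargs that the component's save method actually declares.
        params = inspect.signature(save_method).parameters
        kwargs = {k: v for k, v in candidate_kwargs.items() if k in params and v is not None}
        return save_method(path, **kwargs)

    # e.g. call_with_supported_kwargs(text_encoder.save_pretrained, "out/text_encoder",
    #                                 safe_serialization=True, variant="fp16", max_shard_size="5GB")
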
@@ -713,11 +707,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
is found. If flashpack is unavailable or the `.flashpack` file cannot be used, automatic fallback to
the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install
flashpack`.
use_onnx (`bool`, *optional*, defaults to `None`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
@@ -783,7 +772,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant = kwargs.pop("variant", None)
dduf_file = kwargs.pop("dduf_file", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_flashpack = kwargs.pop("use_flashpack", False)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
quantization_config = kwargs.pop("quantization_config", None)
@@ -1073,7 +1061,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
low_cpu_mem_usage=low_cpu_mem_usage,
cached_folder=cached_folder,
use_safetensors=use_safetensors,
use_flashpack=use_flashpack,
dduf_entries=dduf_entries,
provider_options=provider_options,
disable_mmap=disable_mmap,
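
From the caller's side, the surviving loading options are passed straight through `DiffusionPipeline.from_pretrained`; for example (repo id is a placeholder):

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "some-org/some-pipeline",  # hypothetical repo id
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    )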

View File

@@ -22,6 +22,7 @@ import flax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
@@ -32,6 +33,9 @@ from .scheduling_utils_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class DDIMSchedulerState:
common: CommonSchedulerState
@@ -125,6 +129,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
@@ -152,7 +160,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
)
def scale_model_input(
self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
self,
state: DDIMSchedulerState,
sample: jnp.ndarray,
timestep: Optional[int] = None,
) -> jnp.ndarray:
"""
Args:
@@ -190,7 +201,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
alpha_prod_t = state.common.alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
prev_timestep >= 0,
state.common.alphas_cumprod[prev_timestep],
state.final_alpha_cumprod,
)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
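
Since the deprecation warning is now emitted at construction time, instantiating the Flax scheduler is otherwise unchanged; a minimal sketch (requires JAX and Flax to be installed):

    import jax.numpy as jnp
    from diffusers import FlaxDDIMScheduler

    # Construction now logs the deprecation warning added above.
    scheduler = FlaxDDIMScheduler(num_train_timesteps=1000, dtype=jnp.float32)
    state = scheduler.create_state()
    state = scheduler.set_timesteps(state, num_inference_steps=50)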

View File

@@ -99,7 +99,7 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
@@ -187,14 +187,14 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
clip_sample_range: float = 1.0,
timestep_spacing: str = "leading",
timestep_spacing: Literal["leading", "trailing"] = "leading",
rescale_betas_zero_snr: bool = False,
**kwargs,
):
@@ -210,7 +210,15 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -256,7 +264,11 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int,
device: Optional[Union[str, torch.device]] = None,
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -308,20 +320,10 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
The weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
`use_clipped_model_output` has no effect.
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddim_inverse.DDIMInverseSchedulerOutput`] or
`tuple`.
@@ -335,7 +337,8 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
# 1. get previous step value (=t+1)
prev_timestep = timestep
timestep = min(
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
timestep - self.config.num_train_timesteps // self.num_inference_steps,
self.config.num_train_timesteps - 1,
)
# 2. compute alphas, betas
@@ -378,5 +381,5 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
return (prev_sample, pred_original_sample)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps
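
For reference, the signatures typed above are used together like this in a DDIM inversion loop (schematic sketch; the zero tensor stands in for the model's noise prediction):

    import torch
    from diffusers import DDIMInverseScheduler

    scheduler = DDIMInverseScheduler(beta_schedule="scaled_linear", prediction_type="epsilon")
    scheduler.set_timesteps(num_inference_steps=50)

    sample = torch.randn(1, 4, 64, 64)  # stand-in for image latents
    for t in scheduler.timesteps:
        model_output = torch.zeros_like(sample)  # placeholder for the UNet prediction
        sample = scheduler.step(model_output, t, sample, return_dict=False)[0]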

View File

@@ -1,81 +0,0 @@
import json
import os
from typing import Optional
from ..utils import _add_variant
from .import_utils import is_flashpack_available
from .logging import get_logger
logger = get_logger(__name__)
def save_flashpack(
model,
save_directory: str,
variant: Optional[str] = None,
is_main_process: bool = True,
):
"""
Save model weights in FlashPack format along with a metadata config.
Args:
model: Diffusers model instance
save_directory (`str`): Directory to save weights
variant (`str`, *optional*): Model variant
"""
if not is_flashpack_available():
raise ImportError(
"The `use_flashpack=True` argument requires the `flashpack` package. "
"Install it with `pip install flashpack`."
)
from flashpack import pack_to_file
os.makedirs(save_directory, exist_ok=True)
weights_name = _add_variant("model.flashpack", variant)
weights_path = os.path.join(save_directory, weights_name)
config_path = os.path.join(save_directory, "flashpack_config.json")
try:
target_dtype = getattr(model, "dtype", None)
logger.warning(f"Dtype used for FlashPack save: {target_dtype}")
# 1. Save binary weights
pack_to_file(model, weights_path, target_dtype=target_dtype)
# 2. Save config metadata (best-effort)
if hasattr(model, "config"):
try:
if hasattr(model.config, "to_dict"):
config_data = model.config.to_dict()
else:
config_data = dict(model.config)
with open(config_path, "w") as f:
json.dump(config_data, f, indent=4)
except Exception as config_err:
logger.warning(f"FlashPack weights saved, but config serialization failed: {config_err}")
except Exception as e:
logger.error(f"Failed to save weights in FlashPack format: {e}")
raise
def load_flashpack(model, flashpack_file: str):
"""
Assign FlashPack weights from a file into an initialized PyTorch model.
"""
if not is_flashpack_available():
raise ImportError("FlashPack weights require the `flashpack` package. Install with `pip install flashpack`.")
from flashpack import assign_from_file
logger.warning(f"Loading FlashPack weights from {flashpack_file}")
try:
assign_from_file(model, flashpack_file)
except Exception as e:
raise RuntimeError(f"Failed to load FlashPack weights from {flashpack_file}") from e

View File

@@ -231,7 +231,6 @@ _aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
_av_available, _av_version = _is_package_available("av")
_flashpack_available, _flashpack_version = _is_package_available("flashpack")
def is_torch_available():
@@ -426,10 +425,6 @@ def is_av_available():
return _av_available
def is_flashpack_available():
return _flashpack_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -947,16 +942,6 @@ def is_aiter_version(operation: str, version: str):
return compare_versions(parse(_aiter_version), operation, version)
@cache
def is_flashpack_version(operation: str, version: str):
"""
Compares the current flashpack version to a given reference with an operation.
"""
if not _flashpack_available:
return False
return compare_versions(parse(_flashpack_version), operation, version)
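
The flags removed here follow the module's `_is_package_available` convention; stripped of diffusers internals, that pattern amounts to roughly the following standalone sketch:

    import importlib.util
    from importlib.metadata import PackageNotFoundError, version

    def package_available(pkg_name: str):
        # Returns (available, version_string), mirroring the module-level flags above.
        if importlib.util.find_spec(pkg_name) is None:
            return False, "N/A"
        try:
            return True, version(pkg_name)
        except PackageNotFoundError:
            return False, "N/A"

    print(package_available("torch"))
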
def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, while skipping private/internal objects