Compare commits

..

18 Commits

Author SHA1 Message Date
Sayak Paul
dab372dd27 Merge branch 'main' into enable-cp-kernels 2026-01-26 21:58:00 +08:00
Hameer Abbasi
2af7baa040 Remove *pooled_* mentions from Chroma inpaint (#13026)
Remove `*pooled_*` mentions from Chroma as it has just one TE.
2026-01-26 10:18:29 -03:00
David El Malih
a7cb14efbe Improve docstrings and type hints in scheduling_ddpm_parallel.py (#13027)
* docs: improve docstring scheduling_ddpm_parallel.py

* Update scheduling_ddpm_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-01-25 10:43:43 -08:00
David El Malih
e8e88ff2ce Improve docstrings and type hints in scheduling_ddpm_flax.py (#13024)
docs: improve docstring scheduling_ddpm_flax.py
2026-01-23 11:51:47 -08:00
David El Malih
6e24cd842c Improve docstrings and type hints in scheduling_ddim_parallel.py (#13023)
* docs: improve docstring scheduling_ddim_parallel.py

* docs: improve docstring scheduling_ddim_parallel.py

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* fix style

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-01-23 10:00:32 -08:00
Garry Ling
981eb802c6 feat: add qkv projection fuse for longcat transformers (#13021)
feat: add qkv fuse for longcat transformers

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-01-23 23:02:03 +05:30
jiqing-feng
1eb40c6dbd Resnet only use contiguous in training mode. (#12977)
* fix contiguous

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update tol

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* bigger tol

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update tol

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-01-23 18:40:10 +05:30
Sayak Paul
bff672f47f fix Dockerfiles for cuda and xformers. (#13022) 2026-01-23 16:45:14 +05:30
David El Malih
d4f97d1921 Improve docstrings and type hints in scheduling_ddim_inverse.py (#13020)
docs: improve docstring scheduling_ddim_inverse.py
2026-01-22 15:42:45 -08:00
David El Malih
1d32b19ad4 Improve docstrings and type hints in scheduling_ddim_flax.py (#13010)
* docs: improve docstring scheduling_ddim_flax.py

* docs: improve docstring scheduling_ddim_flax.py

* docs: improve docstring scheduling_ddim_flax.py
2026-01-22 09:11:14 -08:00
Garry Ling
699297f647 feat: accelerate longcat-image with regional compile (#13019) 2026-01-22 20:21:45 +05:30
Sayak Paul
79438572e0 Merge branch 'main' into enable-cp-kernels 2026-01-19 10:28:00 +05:30
sayakpaul
2268583f39 up 2026-01-11 20:05:26 +05:30
Sayak Paul
dfbd4857b2 Merge branch 'main' into enable-cp-kernels 2025-12-17 12:14:40 +08:00
Sayak Paul
9bd83616bf Merge branch 'main' into enable-cp-kernels 2025-12-10 12:33:18 +08:00
sayakpaul
f732ff1144 up 2025-12-09 15:30:33 +05:30
sayakpaul
7a8f85b047 up 2025-12-09 14:59:01 +05:30
sayakpaul
82d20e64a5 up 2025-12-09 14:39:07 +05:30
21 changed files with 605 additions and 369 deletions

View File

@@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ARG PYTHON_VERSION=3.12
ARG PYTHON_VERSION=3.11
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get -y update \
@@ -32,10 +32,12 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
# Install torch, torchvision, and torchaudio together to ensure compatibility
RUN uv pip install --no-cache-dir \
torch \
torchvision \
torchaudio
torchaudio \
--index-url https://download.pytorch.org/whl/cu121
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"

View File

@@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ARG PYTHON_VERSION=3.12
ARG PYTHON_VERSION=3.11
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get -y update \
@@ -32,10 +32,12 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
# Install torch, torchvision, and torchaudio together to ensure compatibility
RUN uv pip install --no-cache-dir \
torch \
torchvision \
torchaudio
torchaudio \
--index-url https://download.pytorch.org/whl/cu121
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"

View File

@@ -260,6 +260,10 @@ class _HubKernelConfig:
function_attr: str
revision: Optional[str] = None
kernel_fn: Optional[Callable] = None
wrapped_forward_attr: Optional[str] = None
wrapped_backward_attr: Optional[str] = None
wrapped_forward_fn: Optional[Callable] = None
wrapped_backward_fn: Optional[Callable] = None
# Registry for hub-based attention kernels
@@ -274,7 +278,11 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
# revision="fake-ops-return-probs",
),
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_func",
revision=None,
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
),
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
@@ -599,22 +607,39 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# ===== Helpers for downloading kernels =====
def _resolve_kernel_attr(module, attr_path: str):
target = module
for attr in attr_path.split("."):
if not hasattr(target, attr):
raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
target = getattr(target, attr)
return target
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
if backend not in _HUB_KERNELS_REGISTRY:
return
config = _HUB_KERNELS_REGISTRY[backend]
if config.kernel_fn is not None:
needs_kernel = config.kernel_fn is None
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
return
try:
from kernels import get_kernel
kernel_module = get_kernel(config.repo_id, revision=config.revision)
kernel_func = getattr(kernel_module, config.function_attr)
if needs_kernel:
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
# Cache the downloaded kernel function in the config object
config.kernel_fn = kernel_func
if needs_wrapped_forward:
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
if needs_wrapped_backward:
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
@@ -1065,6 +1090,231 @@ def _flash_attention_backward_op(
return grad_query, grad_key, grad_value
def _flash_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_forward_fn is None or wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
"for context parallel execution."
)
if scale is None:
scale = query.shape[-1] ** (-0.5)
window_size = (-1, -1)
softcap = 0.0
alibi_slopes = None
deterministic = False
grad_enabled = any(x.requires_grad for x in (query, key, value))
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
dropout_p = dropout_p if dropout_p > 0 else 1e-30
with torch.set_grad_enabled(grad_enabled):
out, lse, S_dmask, rng_state = wrapped_forward_fn(
query,
key,
value,
dropout_p,
scale,
is_causal,
window_size[0],
window_size[1],
softcap,
alibi_slopes,
return_lse,
)
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value, out, lse, rng_state)
ctx.dropout_p = dropout_p
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return (out, lse) if return_lse else out
def _flash_attention_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
)
query, key, value, out, lse, rng_state = ctx.saved_tensors
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
_ = wrapped_backward_fn(
grad_out,
query,
key,
value,
out,
lse,
grad_query,
grad_key,
grad_value,
ctx.dropout_p,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state,
)
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
return grad_query, grad_key, grad_value
def _flash_attention_3_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
*,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
if dropout_p != 0.0:
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=return_lse,
)
lse = None
if return_lse:
out, lse = out
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value)
ctx.scale = scale
ctx.is_causal = is_causal
ctx._hub_kernel = func
return (out, lse) if return_lse else out
def _flash_attention_3_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
):
query, key, value = ctx.saved_tensors
kernel_fn = ctx._hub_kernel
with torch.enable_grad():
query_r = query.detach().requires_grad_(True)
key_r = key.detach().requires_grad_(True)
value_r = value.detach().requires_grad_(True)
out = kernel_fn(
q=query_r,
k=key_r,
v=value_r,
softmax_scale=ctx.scale,
causal=ctx.is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=False,
)
if isinstance(out, tuple):
out = out[0]
grad_query, grad_key, grad_value = torch.autograd.grad(
out,
(query_r, key_r, value_r),
grad_out,
retain_graph=False,
allow_unused=False,
)
return grad_query, grad_key, grad_value
def _sage_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -1103,6 +1353,46 @@ def _sage_attention_forward_op(
return (out, lse) if return_lse else out
def _sage_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
if dropout_p > 0.0:
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
lse = None
if return_lse:
out, lse, *_ = out
lse = lse.permute(0, 2, 1).contiguous()
return (out, lse) if return_lse else out
def _sage_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
@@ -1695,7 +1985,7 @@ def _flash_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_attention_hub(
query: torch.Tensor,
@@ -1713,17 +2003,35 @@ def _flash_attention_hub(
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
is_causal,
scale,
False,
return_lse,
forward_op=_flash_attention_hub_forward_op,
backward_op=_flash_attention_hub_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out
@@ -1870,7 +2178,7 @@ def _flash_attention_3(
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_attention_3_hub(
query: torch.Tensor,
@@ -1885,33 +2193,68 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if _parallel_config:
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
return (out[0], out[1]) if return_attn_probs else out
forward_op = functools.partial(
_flash_attention_3_hub_forward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
# When `return_attn_probs` is True, the above returns a tuple of
# actual outputs and lse.
return (out[0], out[1]) if return_attn_probs else out
backward_op = functools.partial(
_flash_attention_3_hub_backward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
)
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_attn_probs,
forward_op=forward_op,
backward_op=backward_op,
_parallel_config=_parallel_config,
)
if return_attn_probs:
out, lse = out
return out, lse
return out
@_AttentionBackendRegistry.register(
@@ -2542,7 +2885,7 @@ def _sage_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_HUB,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _sage_attention_hub(
query: torch.Tensor,
@@ -2570,6 +2913,23 @@ def _sage_attention_hub(
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_lse,
forward_op=_sage_attention_hub_forward_op,
backward_op=_sage_attention_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out

View File

@@ -675,7 +675,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
variant: Optional[str] = None,
max_shard_size: Union[int, str] = "10GB",
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
@@ -708,9 +707,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
use_flashpack (`bool`, *optional*, defaults to `False`):
Whether to save the model in [FlashPack](https://github.com/fal-ai/flashpack) format. FlashPack is a
binary format that allows for faster loading. Requires the `flashpack` library to be installed.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@@ -731,6 +727,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" the logger on the traceback to understand the reason why the quantized model is not serializable."
)
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
os.makedirs(save_directory, exist_ok=True)
if push_to_hub:
@@ -744,80 +746,67 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Only save the model itself if we are using distributed training
model_to_save = self
# Attach architecture to the config
# Save the config
if is_main_process:
model_to_save.save_config(save_directory)
if use_flashpack:
if not is_main_process:
return
# Save the model
state_dict = model_to_save.state_dict()
from ..utils.flashpack_utils import save_flashpack
# Save the model
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)
save_flashpack(model_to_save, save_directory, variant=variant)
else:
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)
state_dict = model_to_save.state_dict()
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)
# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)
# Save each shard
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)
# Save index file if sharded
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
torch.save(shard, filepath)
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
# Push to hub if requested (common to both paths)
if push_to_hub:
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(repo_id, token=token)
@@ -950,10 +939,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
weights. If set to `False`, `safetensors` weights are not loaded.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is first loaded from `flashpack` (https://github.com/fal-ai/flashpack)
weights if a compatible `.flashpack` file is found. If flashpack is unavailable or the `.flashpack`
file cannot be used, automatic fallback to the standard loading path (for example, `safetensors`).
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
@@ -997,7 +982,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_flashpack = kwargs.pop("use_flashpack", False)
quantization_config = kwargs.pop("quantization_config", None)
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)
@@ -1215,31 +1199,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
model = load_flax_checkpoint_in_pytorch_model(model, resolved_model_file)
flashpack_file = None
if use_flashpack:
try:
flashpack_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant("model.flashpack", variant),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
except EnvironmentError:
flashpack_file = None
logger.warning(
"`use_flashpack` was specified to be True but not flashpack file was found. Resorting to non-flashpack alternatives."
)
if flashpack_file is None:
else:
# in the case it is sharded, we have already the index
if is_sharded:
resolved_model_file, sharded_metadata = _get_checkpoint_shard_files(
@@ -1255,7 +1215,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
dduf_entries=dduf_entries,
)
elif use_safetensors:
logger.warning("Trying to load model weights with safetensors format.")
try:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
@@ -1321,29 +1280,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
if flashpack_file is not None:
from ..utils.flashpack_utils import load_flashpack
# Even when using FlashPack, we preserve `low_cpu_mem_usage` behavior by initializing
# the model with meta tensors. Since FlashPack cannot write into meta tensors, we
# explicitly materialize parameters before loading to ensure correctness and parity
# with the standard loading path.
if any(p.device.type == "meta" for p in model.parameters()):
model.to_empty(device="cpu")
load_flashpack(model, flashpack_file)
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
model.eval()
if output_loading_info:
return model, {
"missing_keys": [],
"unexpected_keys": [],
"mismatched_keys": [],
"error_msgs": [],
}
return model
state_dict = None
if not is_sharded:
# Time to load the checkpoint
@@ -1391,6 +1327,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
keep_in_fp32_modules=keep_in_fp32_modules,
dduf_entries=dduf_entries,
is_parallel_loading_enabled=is_parallel_loading_enabled,
disable_mmap=disable_mmap,
)
loading_info = {
"missing_keys": missing_keys,
@@ -1436,8 +1373,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if output_loading_info:
return model, loading_info
logger.warning(f"Model till end {pretrained_model_name_or_path} loaded successfully")
return model
# Adapted from `transformers`.

View File

@@ -366,7 +366,12 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor.contiguous())
# Only use contiguous() during training to avoid DDP gradient stride mismatch warning.
# In inference mode (eval or no_grad), skip contiguous() for better performance, especially on CPU.
# Issue: https://github.com/huggingface/diffusers/issues/12975
if self.training:
input_tensor = input_tensor.contiguous()
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

View File

@@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
@@ -400,12 +400,14 @@ class LongCatImageTransformer2DModel(
PeftAdapterMixin,
FromOriginalModelMixin,
CacheMixin,
AttentionMixin,
):
"""
The Transformer model introduced in Longcat-Image.
"""
_supports_gradient_checkpointing = True
_repeated_blocks = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"]
@register_to_config
def __init__(

View File

@@ -482,8 +482,6 @@ class ChromaInpaintPipeline(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
padding_mask_crop=None,
max_sequence_length=None,
@@ -531,15 +529,6 @@ class ChromaInpaintPipeline(
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")
@@ -793,13 +782,11 @@ class ChromaInpaintPipeline(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,

View File

@@ -756,7 +756,6 @@ def load_sub_model(
low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike],
use_safetensors: bool,
use_flashpack: bool,
dduf_entries: Optional[Dict[str, DDUFEntry]],
provider_options: Any,
disable_mmap: bool,
@@ -839,9 +838,6 @@ def load_sub_model(
loading_kwargs["variant"] = model_variants.pop(name, None)
loading_kwargs["use_safetensors"] = use_safetensors
if is_diffusers_model:
loading_kwargs["use_flashpack"] = use_flashpack
if from_flax:
loading_kwargs["from_flax"] = True

View File

@@ -243,7 +243,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant: Optional[str] = None,
max_shard_size: Optional[Union[int, str]] = None,
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
@@ -269,9 +268,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
namespace).
use_flashpack (`bool`, *optional*, defaults to `False`):
Whether or not to use `flashpack` to save the model weights. Requires the `flashpack` library: `pip
install flashpack`.
kwargs (`Dict[str, Any]`, *optional*):
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
@@ -343,7 +340,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
save_method_accept_variant = "variant" in save_method_signature.parameters
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters
save_kwargs = {}
if save_method_accept_safe:
@@ -353,8 +349,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if save_method_accept_max_shard_size and max_shard_size is not None:
# max_shard_size is expected to not be None in ModelMixin
save_kwargs["max_shard_size"] = max_shard_size
if save_method_accept_flashpack:
save_kwargs["use_flashpack"] = use_flashpack
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
@@ -713,11 +707,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
weights. If set to `False`, safetensors weights are not loaded.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is first loaded from `flashpack` weights if a compatible `.flashpack` file
is found. If flashpack is unavailable or the `.flashpack` file cannot be used, automatic fallback to
the standard loading path (for example, `safetensors`). Requires the `flashpack` library: `pip install
flashpack`.
use_onnx (`bool`, *optional*, defaults to `None`):
If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights
will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is
@@ -783,7 +772,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant = kwargs.pop("variant", None)
dduf_file = kwargs.pop("dduf_file", None)
use_safetensors = kwargs.pop("use_safetensors", None)
use_flashpack = kwargs.pop("use_flashpack", False)
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
quantization_config = kwargs.pop("quantization_config", None)
@@ -1073,7 +1061,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
low_cpu_mem_usage=low_cpu_mem_usage,
cached_folder=cached_folder,
use_safetensors=use_safetensors,
use_flashpack=use_flashpack,
dduf_entries=dduf_entries,
provider_options=provider_options,
disable_mmap=disable_mmap,

View File

@@ -22,6 +22,7 @@ import flax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
@@ -32,6 +33,9 @@ from .scheduling_utils_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class DDIMSchedulerState:
common: CommonSchedulerState
@@ -125,6 +129,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
@@ -152,7 +160,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
)
def scale_model_input(
self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
self,
state: DDIMSchedulerState,
sample: jnp.ndarray,
timestep: Optional[int] = None,
) -> jnp.ndarray:
"""
Args:
@@ -190,7 +201,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
alpha_prod_t = state.common.alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
prev_timestep >= 0,
state.common.alphas_cumprod[prev_timestep],
state.final_alpha_cumprod,
)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

View File

@@ -99,7 +99,7 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
@@ -187,14 +187,14 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
clip_sample_range: float = 1.0,
timestep_spacing: str = "leading",
timestep_spacing: Literal["leading", "trailing"] = "leading",
rescale_betas_zero_snr: bool = False,
**kwargs,
):
@@ -210,7 +210,15 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -256,7 +264,11 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int,
device: Optional[Union[str, torch.device]] = None,
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -308,20 +320,10 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
The weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
`use_clipped_model_output` has no effect.
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddim_inverse.DDIMInverseSchedulerOutput`] or
`tuple`.
@@ -335,7 +337,8 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
# 1. get previous step value (=t+1)
prev_timestep = timestep
timestep = min(
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
timestep - self.config.num_train_timesteps // self.num_inference_steps,
self.config.num_train_timesteps - 1,
)
# 2. compute alphas, betas
@@ -378,5 +381,5 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
return (prev_sample, pred_original_sample)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -101,7 +101,7 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
@@ -266,7 +266,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def _get_variance(self, timestep, prev_timestep=None):
def _get_variance(self, timestep: int, prev_timestep: Optional[int] = None) -> torch.Tensor:
if prev_timestep is None:
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
@@ -279,7 +279,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
return variance
def _batch_get_variance(self, t, prev_t):
def _batch_get_variance(self, t: torch.Tensor, prev_t: torch.Tensor) -> torch.Tensor:
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)
@@ -335,7 +335,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
return sample
# Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -392,7 +392,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
sample: torch.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
generator: Optional[torch.Generator] = None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[DDIMParallelSchedulerOutput, Tuple]:
@@ -406,11 +406,13 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
sample (`torch.Tensor`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
`self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
coincide with the one provided as input and `use_clipped_model_output` will have not effect.
generator: random number generator.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, compute "corrected" `model_output` from the clipped predicted original sample. This
correction is necessary because the predicted original sample is clipped to [-1, 1] when
`self.config.clip_sample` is `True`. If no clipping occurred, the "corrected" `model_output` matches
the input and `use_clipped_model_output` has no effect.
generator (`torch.Generator`, *optional*):
Random number generator.
variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we
can directly provide the noise for the variance itself. This is useful for methods such as
CycleDiffusion. (https://huggingface.co/papers/2210.05559)
@@ -496,7 +498,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
if variance_noise is None:
variance_noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
model_output.shape,
generator=generator,
device=model_output.device,
dtype=model_output.dtype,
)
variance = std_dev_t * variance_noise
@@ -513,7 +518,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
def batch_step_no_noise(
self,
model_output: torch.Tensor,
timesteps: List[int],
timesteps: torch.Tensor,
sample: torch.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
@@ -528,7 +533,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`): direct output from learned diffusion model.
timesteps (`List[int]`):
timesteps (`torch.Tensor`):
current discrete timesteps in the diffusion chain. This is now a list of integers.
sample (`torch.Tensor`):
current instance of sample being created by diffusion process.
@@ -696,5 +701,5 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -281,7 +281,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
@@ -646,7 +646,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
def __len__(self) -> int:
return self.config.num_train_timesteps
def previous_timestep(self, timestep: int) -> int:
def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]:
"""
Compute the previous timestep in the diffusion chain.
@@ -655,7 +655,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
The current timestep.
Returns:
`int`:
`int` or `torch.Tensor`:
The previous timestep.
"""
if self.custom_timesteps or self.num_inference_steps:

View File

@@ -22,6 +22,7 @@ import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
@@ -32,6 +33,9 @@ from .scheduling_utils_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class DDPMSchedulerState:
common: CommonSchedulerState
@@ -42,7 +46,12 @@ class DDPMSchedulerState:
num_inference_steps: Optional[int] = None
@classmethod
def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray):
def create(
cls,
common: CommonSchedulerState,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
):
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps)
@@ -105,6 +114,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
@@ -123,7 +136,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
)
def scale_model_input(
self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
self,
state: DDPMSchedulerState,
sample: jnp.ndarray,
timestep: Optional[int] = None,
) -> jnp.ndarray:
"""
Args:

View File

@@ -149,38 +149,41 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
For more details, see the original paper: https://huggingface.co/papers/2006.11239
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
variance_type (`str`):
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
trained_betas (`np.ndarray`, *optional*):
Option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
variance_type (`str`, defaults to `"fixed_small"`):
Options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample for numerical stability.
clip_sample_range (`float`, default `1.0`):
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
clip_sample (`bool`, defaults to `True`):
Option to clip predicted sample for numerical stability.
prediction_type (`str`, defaults to `"epsilon"`):
Prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://huggingface.co/papers/2210.02303)
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen,
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method (introduced by Imagen,
https://huggingface.co/papers/2205.11487). Note that the thresholding method is unsuitable for latent-space
diffusion models (such as stable-diffusion).
dynamic_thresholding_ratio (`float`, default `0.995`):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
dynamic_thresholding_ratio (`float`, defaults to 0.995):
The ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
(https://huggingface.co/papers/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, default `"leading"`):
clip_sample_range (`float`, defaults to 1.0):
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, default `0`):
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
@@ -293,7 +296,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
@@ -478,7 +481,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
generator=None,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[DDPMParallelSchedulerOutput, Tuple]:
"""
@@ -490,7 +493,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
current instance of sample being created by diffusion process.
generator: random number generator.
generator (`torch.Generator`, *optional*):
Random number generator.
return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class
Returns:
@@ -503,7 +507,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
prev_t = self.previous_timestep(t)
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
"learned",
"learned_range",
]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
predicted_variance = None
@@ -552,7 +559,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
if t > 0:
device = model_output.device
variance_noise = randn_tensor(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
model_output.shape,
generator=generator,
device=device,
dtype=model_output.dtype,
)
if self.variance_type == "fixed_small_log":
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
@@ -575,7 +585,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
def batch_step_no_noise(
self,
model_output: torch.Tensor,
timesteps: List[int],
timesteps: torch.Tensor,
sample: torch.Tensor,
) -> torch.Tensor:
"""
@@ -588,8 +598,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`): direct output from learned diffusion model.
timesteps (`List[int]`):
current discrete timesteps in the diffusion chain. This is now a list of integers.
timesteps (`torch.Tensor`):
Current discrete timesteps in the diffusion chain. This is a tensor of integers.
sample (`torch.Tensor`):
current instance of sample being created by diffusion process.
@@ -603,7 +613,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
t = t.view(-1, *([1] * (model_output.ndim - 1)))
prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1)))
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
"learned",
"learned_range",
]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
pass
@@ -734,7 +747,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
return self.config.num_train_timesteps
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]:
"""
Compute the previous timestep in the diffusion chain.
@@ -743,7 +756,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
The current timestep.
Returns:
`int`:
`int` or `torch.Tensor`:
The previous timestep.
"""
if self.custom_timesteps or self.num_inference_steps:

View File

@@ -722,7 +722,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
The current timestep.
Returns:
`int`:
`int` or `torch.Tensor`:
The previous timestep.
"""
if self.custom_timesteps or self.num_inference_steps:

View File

@@ -777,7 +777,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
The current timestep.
Returns:
`int`:
`int` or `torch.Tensor`:
The previous timestep.
"""
if self.custom_timesteps or self.num_inference_steps:

View File

@@ -1,81 +0,0 @@
import json
import os
from typing import Optional
from ..utils import _add_variant
from .import_utils import is_flashpack_available
from .logging import get_logger
logger = get_logger(__name__)
def save_flashpack(
model,
save_directory: str,
variant: Optional[str] = None,
is_main_process: bool = True,
):
"""
Save model weights in FlashPack format along with a metadata config.
Args:
model: Diffusers model instance
save_directory (`str`): Directory to save weights
variant (`str`, *optional*): Model variant
"""
if not is_flashpack_available():
raise ImportError(
"The `use_flashpack=True` argument requires the `flashpack` package. "
"Install it with `pip install flashpack`."
)
from flashpack import pack_to_file
os.makedirs(save_directory, exist_ok=True)
weights_name = _add_variant("model.flashpack", variant)
weights_path = os.path.join(save_directory, weights_name)
config_path = os.path.join(save_directory, "flashpack_config.json")
try:
target_dtype = getattr(model, "dtype", None)
logger.warning(f"Dtype used for FlashPack save: {target_dtype}")
# 1. Save binary weights
pack_to_file(model, weights_path, target_dtype=target_dtype)
# 2. Save config metadata (best-effort)
if hasattr(model, "config"):
try:
if hasattr(model.config, "to_dict"):
config_data = model.config.to_dict()
else:
config_data = dict(model.config)
with open(config_path, "w") as f:
json.dump(config_data, f, indent=4)
except Exception as config_err:
logger.warning(f"FlashPack weights saved, but config serialization failed: {config_err}")
except Exception as e:
logger.error(f"Failed to save weights in FlashPack format: {e}")
raise
def load_flashpack(model, flashpack_file: str):
"""
Assign FlashPack weights from a file into an initialized PyTorch model.
"""
if not is_flashpack_available():
raise ImportError("FlashPack weights require the `flashpack` package. Install with `pip install flashpack`.")
from flashpack import assign_from_file
logger.warning(f"Loading FlashPack weights from {flashpack_file}")
try:
assign_from_file(model, flashpack_file)
except Exception as e:
raise RuntimeError(f"Failed to load FlashPack weights from {flashpack_file}") from e

View File

@@ -231,7 +231,6 @@ _aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
_av_available, _av_version = _is_package_available("av")
_flashpack_available, _flashpack_version = _is_package_available("flashpack")
def is_torch_available():
@@ -426,10 +425,6 @@ def is_av_available():
return _av_available
def is_flashpack_available():
return _flashpack_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -947,16 +942,6 @@ def is_aiter_version(operation: str, version: str):
return compare_versions(parse(_aiter_version), operation, version)
@cache
def is_flashpack_version(operation: str, version: str):
"""
Compares the current flashpack version to a given reference with an operation.
"""
if not _flashpack_available:
return False
return compare_versions(parse(_flashpack_version), operation, version)
def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, while skipping private/internal objects

View File

@@ -248,6 +248,9 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=5e-1)
def test_save_load_dduf(self):
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
@is_flaky()
def test_model_cpu_offload_forward_pass(self):
super().test_inference_batch_single_identical(expected_max_diff=8e-4)

View File

@@ -191,6 +191,9 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=1e-2)
def test_save_load_dduf(self):
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
@slow
@require_torch_accelerator