Compare commits

..

14 Commits

Author SHA1 Message Date
Sayak Paul
dab372dd27 Merge branch 'main' into enable-cp-kernels 2026-01-26 21:58:00 +08:00
Hameer Abbasi
2af7baa040 Remove *pooled_* mentions from Chroma inpaint (#13026)
Remove `*pooled_*` mentions from Chroma as it has just one TE.
2026-01-26 10:18:29 -03:00
David El Malih
a7cb14efbe Improve docstrings and type hints in scheduling_ddpm_parallel.py (#13027)
* docs: improve docstring scheduling_ddpm_parallel.py

* Update scheduling_ddpm_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-01-25 10:43:43 -08:00
David El Malih
e8e88ff2ce Improve docstrings and type hints in scheduling_ddpm_flax.py (#13024)
docs: improve docstring scheduling_ddpm_flax.py
2026-01-23 11:51:47 -08:00
David El Malih
6e24cd842c Improve docstrings and type hints in scheduling_ddim_parallel.py (#13023)
* docs: improve docstring scheduling_ddim_parallel.py

* docs: improve docstring scheduling_ddim_parallel.py

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* fix style

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-01-23 10:00:32 -08:00
Garry Ling
981eb802c6 feat: add qkv projection fuse for longcat transformers (#13021)
feat: add qkv fuse for longcat transformers

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-01-23 23:02:03 +05:30
jiqing-feng
1eb40c6dbd Resnet only use contiguous in training mode. (#12977)
* fix contiguous

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update tol

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* bigger tol

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update tol

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-01-23 18:40:10 +05:30
Sayak Paul
79438572e0 Merge branch 'main' into enable-cp-kernels 2026-01-19 10:28:00 +05:30
sayakpaul
2268583f39 up 2026-01-11 20:05:26 +05:30
Sayak Paul
dfbd4857b2 Merge branch 'main' into enable-cp-kernels 2025-12-17 12:14:40 +08:00
Sayak Paul
9bd83616bf Merge branch 'main' into enable-cp-kernels 2025-12-10 12:33:18 +08:00
sayakpaul
f732ff1144 up 2025-12-09 15:30:33 +05:30
sayakpaul
7a8f85b047 up 2025-12-09 14:59:01 +05:30
sayakpaul
82d20e64a5 up 2025-12-09 14:39:07 +05:30
14 changed files with 500 additions and 114 deletions

View File

@@ -59,12 +59,6 @@ class ContextParallelConfig:
rotate_method (`str`, *optional*, defaults to `"allgather"`):
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
is supported.
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
creating a new one. This is useful when combining context parallelism with other parallelism strategies
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
"""
@@ -73,7 +67,6 @@ class ContextParallelConfig:
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None
_rank: int = None
_world_size: int = None
@@ -122,7 +115,7 @@ class ContextParallelConfig:
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
self._flattened_mesh = self._mesh._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()

View File

@@ -260,6 +260,10 @@ class _HubKernelConfig:
function_attr: str
revision: Optional[str] = None
kernel_fn: Optional[Callable] = None
wrapped_forward_attr: Optional[str] = None
wrapped_backward_attr: Optional[str] = None
wrapped_forward_fn: Optional[Callable] = None
wrapped_backward_fn: Optional[Callable] = None
# Registry for hub-based attention kernels
@@ -274,7 +278,11 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
# revision="fake-ops-return-probs",
),
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_func",
revision=None,
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
),
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
@@ -599,22 +607,39 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# ===== Helpers for downloading kernels =====
def _resolve_kernel_attr(module, attr_path: str):
target = module
for attr in attr_path.split("."):
if not hasattr(target, attr):
raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
target = getattr(target, attr)
return target
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
if backend not in _HUB_KERNELS_REGISTRY:
return
config = _HUB_KERNELS_REGISTRY[backend]
if config.kernel_fn is not None:
needs_kernel = config.kernel_fn is None
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
return
try:
from kernels import get_kernel
kernel_module = get_kernel(config.repo_id, revision=config.revision)
kernel_func = getattr(kernel_module, config.function_attr)
if needs_kernel:
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
# Cache the downloaded kernel function in the config object
config.kernel_fn = kernel_func
if needs_wrapped_forward:
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
if needs_wrapped_backward:
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
@@ -1065,6 +1090,231 @@ def _flash_attention_backward_op(
return grad_query, grad_key, grad_value
def _flash_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_forward_fn is None or wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
"for context parallel execution."
)
if scale is None:
scale = query.shape[-1] ** (-0.5)
window_size = (-1, -1)
softcap = 0.0
alibi_slopes = None
deterministic = False
grad_enabled = any(x.requires_grad for x in (query, key, value))
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
dropout_p = dropout_p if dropout_p > 0 else 1e-30
with torch.set_grad_enabled(grad_enabled):
out, lse, S_dmask, rng_state = wrapped_forward_fn(
query,
key,
value,
dropout_p,
scale,
is_causal,
window_size[0],
window_size[1],
softcap,
alibi_slopes,
return_lse,
)
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value, out, lse, rng_state)
ctx.dropout_p = dropout_p
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return (out, lse) if return_lse else out
def _flash_attention_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
)
query, key, value, out, lse, rng_state = ctx.saved_tensors
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
_ = wrapped_backward_fn(
grad_out,
query,
key,
value,
out,
lse,
grad_query,
grad_key,
grad_value,
ctx.dropout_p,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state,
)
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
return grad_query, grad_key, grad_value
def _flash_attention_3_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
*,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
if dropout_p != 0.0:
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=return_lse,
)
lse = None
if return_lse:
out, lse = out
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value)
ctx.scale = scale
ctx.is_causal = is_causal
ctx._hub_kernel = func
return (out, lse) if return_lse else out
def _flash_attention_3_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: Optional[bool] = None,
deterministic: bool = False,
sm_margin: int = 0,
):
query, key, value = ctx.saved_tensors
kernel_fn = ctx._hub_kernel
with torch.enable_grad():
query_r = query.detach().requires_grad_(True)
key_r = key.detach().requires_grad_(True)
value_r = value.detach().requires_grad_(True)
out = kernel_fn(
q=query_r,
k=key_r,
v=value_r,
softmax_scale=ctx.scale,
causal=ctx.is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=False,
)
if isinstance(out, tuple):
out = out[0]
grad_query, grad_key, grad_value = torch.autograd.grad(
out,
(query_r, key_r, value_r),
grad_out,
retain_graph=False,
allow_unused=False,
)
return grad_query, grad_key, grad_value
def _sage_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -1103,6 +1353,46 @@ def _sage_attention_forward_op(
return (out, lse) if return_lse else out
def _sage_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
if dropout_p > 0.0:
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
lse = None
if return_lse:
out, lse, *_ = out
lse = lse.permute(0, 2, 1).contiguous()
return (out, lse) if return_lse else out
def _sage_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
@@ -1695,7 +1985,7 @@ def _flash_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_attention_hub(
query: torch.Tensor,
@@ -1713,17 +2003,35 @@ def _flash_attention_hub(
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
is_causal,
scale,
False,
return_lse,
forward_op=_flash_attention_hub_forward_op,
backward_op=_flash_attention_hub_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out
@@ -1870,7 +2178,7 @@ def _flash_attention_3(
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_attention_3_hub(
query: torch.Tensor,
@@ -1885,33 +2193,68 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if _parallel_config:
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
return (out[0], out[1]) if return_attn_probs else out
forward_op = functools.partial(
_flash_attention_3_hub_forward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
# When `return_attn_probs` is True, the above returns a tuple of
# actual outputs and lse.
return (out[0], out[1]) if return_attn_probs else out
backward_op = functools.partial(
_flash_attention_3_hub_backward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
)
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_attn_probs,
forward_op=forward_op,
backward_op=backward_op,
_parallel_config=_parallel_config,
)
if return_attn_probs:
out, lse = out
return out, lse
return out
@_AttentionBackendRegistry.register(
@@ -2542,7 +2885,7 @@ def _sage_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_HUB,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _sage_attention_hub(
query: torch.Tensor,
@@ -2570,6 +2913,23 @@ def _sage_attention_hub(
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_lse,
forward_op=_sage_attention_hub_forward_op,
backward_op=_sage_attention_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out

View File

@@ -1569,7 +1569,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
mesh = None
if config.context_parallel_config is not None:
cp_config = config.context_parallel_config
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
mesh = torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=cp_config.mesh_shape,
mesh_dim_names=cp_config.mesh_dim_names,

View File

@@ -366,7 +366,12 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor.contiguous())
# Only use contiguous() during training to avoid DDP gradient stride mismatch warning.
# In inference mode (eval or no_grad), skip contiguous() for better performance, especially on CPU.
# Issue: https://github.com/huggingface/diffusers/issues/12975
if self.training:
input_tensor = input_tensor.contiguous()
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

View File

@@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
@@ -400,6 +400,7 @@ class LongCatImageTransformer2DModel(
PeftAdapterMixin,
FromOriginalModelMixin,
CacheMixin,
AttentionMixin,
):
"""
The Transformer model introduced in Longcat-Image.

View File

@@ -482,8 +482,6 @@ class ChromaInpaintPipeline(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
padding_mask_crop=None,
max_sequence_length=None,
@@ -531,15 +529,6 @@ class ChromaInpaintPipeline(
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask")
@@ -793,13 +782,11 @@ class ChromaInpaintPipeline(
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,

View File

@@ -101,7 +101,7 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
@@ -266,7 +266,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def _get_variance(self, timestep, prev_timestep=None):
def _get_variance(self, timestep: int, prev_timestep: Optional[int] = None) -> torch.Tensor:
if prev_timestep is None:
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
@@ -279,7 +279,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
return variance
def _batch_get_variance(self, t, prev_t):
def _batch_get_variance(self, t: torch.Tensor, prev_t: torch.Tensor) -> torch.Tensor:
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)
@@ -335,7 +335,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
return sample
# Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -392,7 +392,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
sample: torch.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
generator: Optional[torch.Generator] = None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[DDIMParallelSchedulerOutput, Tuple]:
@@ -406,11 +406,13 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
sample (`torch.Tensor`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
`self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
coincide with the one provided as input and `use_clipped_model_output` will have not effect.
generator: random number generator.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, compute "corrected" `model_output` from the clipped predicted original sample. This
correction is necessary because the predicted original sample is clipped to [-1, 1] when
`self.config.clip_sample` is `True`. If no clipping occurred, the "corrected" `model_output` matches
the input and `use_clipped_model_output` has no effect.
generator (`torch.Generator`, *optional*):
Random number generator.
variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we
can directly provide the noise for the variance itself. This is useful for methods such as
CycleDiffusion. (https://huggingface.co/papers/2210.05559)
@@ -496,7 +498,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
if variance_noise is None:
variance_noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
model_output.shape,
generator=generator,
device=model_output.device,
dtype=model_output.dtype,
)
variance = std_dev_t * variance_noise
@@ -513,7 +518,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
def batch_step_no_noise(
self,
model_output: torch.Tensor,
timesteps: List[int],
timesteps: torch.Tensor,
sample: torch.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
@@ -528,7 +533,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`): direct output from learned diffusion model.
timesteps (`List[int]`):
timesteps (`torch.Tensor`):
current discrete timesteps in the diffusion chain. This is now a list of integers.
sample (`torch.Tensor`):
current instance of sample being created by diffusion process.
@@ -696,5 +701,5 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -281,7 +281,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
@@ -646,7 +646,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
def __len__(self) -> int:
return self.config.num_train_timesteps
def previous_timestep(self, timestep: int) -> int:
def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]:
"""
Compute the previous timestep in the diffusion chain.
@@ -655,7 +655,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
The current timestep.
Returns:
`int`:
`int` or `torch.Tensor`:
The previous timestep.
"""
if self.custom_timesteps or self.num_inference_steps:

View File

@@ -22,6 +22,7 @@ import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
@@ -32,6 +33,9 @@ from .scheduling_utils_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class DDPMSchedulerState:
common: CommonSchedulerState
@@ -42,7 +46,12 @@ class DDPMSchedulerState:
num_inference_steps: Optional[int] = None
@classmethod
def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray):
def create(
cls,
common: CommonSchedulerState,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
):
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps)
@@ -105,6 +114,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
@@ -123,7 +136,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
)
def scale_model_input(
self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
self,
state: DDPMSchedulerState,
sample: jnp.ndarray,
timestep: Optional[int] = None,
) -> jnp.ndarray:
"""
Args:

View File

@@ -149,38 +149,41 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
For more details, see the original paper: https://huggingface.co/papers/2006.11239
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.0001):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`.
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
variance_type (`str`):
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
trained_betas (`np.ndarray`, *optional*):
Option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
variance_type (`str`, defaults to `"fixed_small"`):
Options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample for numerical stability.
clip_sample_range (`float`, default `1.0`):
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
clip_sample (`bool`, defaults to `True`):
Option to clip predicted sample for numerical stability.
prediction_type (`str`, defaults to `"epsilon"`):
Prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://huggingface.co/papers/2210.02303)
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen,
thresholding (`bool`, defaults to `False`):
Whether to use the "dynamic thresholding" method (introduced by Imagen,
https://huggingface.co/papers/2205.11487). Note that the thresholding method is unsuitable for latent-space
diffusion models (such as stable-diffusion).
dynamic_thresholding_ratio (`float`, default `0.995`):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
dynamic_thresholding_ratio (`float`, defaults to 0.995):
The ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
(https://huggingface.co/papers/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, default `"leading"`):
clip_sample_range (`float`, defaults to 1.0):
The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
sample_max_value (`float`, defaults to 1.0):
The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
timestep_spacing (`str`, defaults to `"leading"`):
The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample
Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
steps_offset (`int`, default `0`):
steps_offset (`int`, defaults to 0):
An offset added to the inference steps, as required by some model families.
rescale_betas_zero_snr (`bool`, defaults to `False`):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
@@ -293,7 +296,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model. If used,
`timesteps` must be `None`.
device (`str` or `torch.device`, *optional*):
@@ -478,7 +481,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
model_output: torch.Tensor,
timestep: int,
sample: torch.Tensor,
generator=None,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[DDPMParallelSchedulerOutput, Tuple]:
"""
@@ -490,7 +493,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
current instance of sample being created by diffusion process.
generator: random number generator.
generator (`torch.Generator`, *optional*):
Random number generator.
return_dict (`bool`): option for returning tuple rather than DDPMParallelSchedulerOutput class
Returns:
@@ -503,7 +507,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
prev_t = self.previous_timestep(t)
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
"learned",
"learned_range",
]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
predicted_variance = None
@@ -552,7 +559,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
if t > 0:
device = model_output.device
variance_noise = randn_tensor(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
model_output.shape,
generator=generator,
device=device,
dtype=model_output.dtype,
)
if self.variance_type == "fixed_small_log":
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
@@ -575,7 +585,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
def batch_step_no_noise(
self,
model_output: torch.Tensor,
timesteps: List[int],
timesteps: torch.Tensor,
sample: torch.Tensor,
) -> torch.Tensor:
"""
@@ -588,8 +598,8 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`): direct output from learned diffusion model.
timesteps (`List[int]`):
current discrete timesteps in the diffusion chain. This is now a list of integers.
timesteps (`torch.Tensor`):
Current discrete timesteps in the diffusion chain. This is a tensor of integers.
sample (`torch.Tensor`):
current instance of sample being created by diffusion process.
@@ -603,7 +613,10 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
t = t.view(-1, *([1] * (model_output.ndim - 1)))
prev_t = prev_t.view(-1, *([1] * (model_output.ndim - 1)))
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
"learned",
"learned_range",
]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
pass
@@ -734,7 +747,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
return self.config.num_train_timesteps
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
def previous_timestep(self, timestep):
def previous_timestep(self, timestep: int) -> Union[int, torch.Tensor]:
"""
Compute the previous timestep in the diffusion chain.
@@ -743,7 +756,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
The current timestep.
Returns:
`int`:
`int` or `torch.Tensor`:
The previous timestep.
"""
if self.custom_timesteps or self.num_inference_steps:

View File

@@ -722,7 +722,7 @@ class LCMScheduler(SchedulerMixin, ConfigMixin):
The current timestep.
Returns:
`int`:
`int` or `torch.Tensor`:
The previous timestep.
"""
if self.custom_timesteps or self.num_inference_steps:

View File

@@ -777,7 +777,7 @@ class TCDScheduler(SchedulerMixin, ConfigMixin):
The current timestep.
Returns:
`int`:
`int` or `torch.Tensor`:
The previous timestep.
"""
if self.custom_timesteps or self.num_inference_steps:

View File

@@ -248,6 +248,9 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=5e-1)
def test_save_load_dduf(self):
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
@is_flaky()
def test_model_cpu_offload_forward_pass(self):
super().test_inference_batch_single_identical(expected_max_diff=8e-4)

View File

@@ -191,6 +191,9 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=1e-2)
def test_save_load_dduf(self):
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
@slow
@require_torch_accelerator