Compare commits

...

11 Commits

Author SHA1 Message Date
sayakpaul
67f4691cab resolve conflicts. 2026-02-19 18:22:49 +05:30
Sayak Paul
99daaa802d [core] Enable CP for kernels-based attention backends (#12812)
* up

* up

* up

* up

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2026-02-19 18:16:50 +05:30
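For orientation, a rough sketch of how a kernels-based attention backend could be combined with context parallelism after #12812. The `ContextParallelConfig` / `enable_parallelism` / `set_attention_backend` names and the `"flash_hub"` backend string are assumptions drawn from diffusers' parallelism utilities and the registry changed below, not something this diff defines; treat it as illustrative only.

```python
# Illustrative sketch only (assumed API names: set_attention_backend, enable_parallelism,
# ContextParallelConfig). Launch with: torchrun --nproc-per-node=2 run_cp.py
import torch
from diffusers import ContextParallelConfig, FluxPipeline

torch.distributed.init_process_group(backend="nccl")
rank = torch.distributed.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Route attention through the Hub-provided flash-attn2 kernel (needs `kernels>=0.12`, per this PR)...
pipe.transformer.set_attention_backend("flash_hub")
# ...and shard the sequence across ranks with ring attention, now that the backend supports CP.
pipe.transformer.enable_parallelism(
    config=ContextParallelConfig(ring_degree=torch.distributed.get_world_size())
)

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
if rank == 0:
    image.save("flux_cp.png")
```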
sayakpaul
e10fe61303 fix version and force updated kernels. 2026-02-19 18:00:01 +05:30
Sayak Paul
348350cf24 Merge branch 'main' into update-kernel-hub-repos 2026-02-19 17:53:46 +05:30
dg845
fe78a7b7c6 Fix ftfy import for PRX Pipeline (#13154)
* Guard ftfy import with is_ftfy_available

* Remove xfail for PRX pipeline tests as they appear to work on transformers>4.57.1

* make style and make quality
2026-02-18 20:44:33 -08:00
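A minimal sketch of the optional-import guard this commit applies (see the PRX pipeline diff further down). Only `is_ftfy_available` comes from `diffusers.utils`; the `clean_caption` helper here is hypothetical.

```python
# Guard an optional dependency so the pipeline imports even when ftfy is missing.
from diffusers.utils import is_ftfy_available

if is_ftfy_available():
    import ftfy


def clean_caption(text: str) -> str:
    # Repair mojibake with ftfy only when the optional dependency is installed.
    if is_ftfy_available():
        text = ftfy.fix_text(text)
    return text.strip()


print(clean_caption("  The Midnight   Club  "))
```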
dg845
53e1d0e458 [CI] Revert setuptools CI Fix as the Failing Pipelines are Deprecated (#13149)
* Pin setuptools version for dependencies which explicitly depend on pkg_resources

* Revert setuptools pin as k-diffusion pipelines are now deprecated

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-02-18 20:34:00 -08:00
Sayak Paul
af35e3806c Merge branch 'main' into update-kernel-hub-repos 2026-02-19 09:35:15 +05:30
dxqb
a577ec36df Flux2: Tensor tuples can cause issues for checkpointing (#12777)
* split tensors inside the transformer blocks to avoid checkpointing issues

* clean up, fix type hints

* fix merge error

* Apply style fixes

---------

Co-authored-by: s <you@example.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-02-18 17:03:22 -08:00
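A minimal self-contained sketch of the pattern this PR adopts: `torch.utils.checkpoint` only tracks tensor arguments, so the packed modulation tensor is passed through the checkpointed block and split inside it, rather than being pre-split into a tuple of tensors outside. `ToyBlock` is hypothetical; the real change is in the Flux2 transformer blocks shown below.

```python
# Pass a single packed tensor through activation checkpointing and split it inside the block,
# instead of passing a tuple of tensors (which checkpointing does not treat as tensor inputs).
import torch
from torch.utils.checkpoint import checkpoint


class ToyBlock(torch.nn.Module):
    def forward(self, x: torch.Tensor, packed_mod: torch.Tensor) -> torch.Tensor:
        # Split shift/scale/gate from the packed modulation tensor *inside* the block.
        shift, scale, gate = packed_mod.chunk(3, dim=-1)
        return gate * ((1 + scale) * x + shift)


block = ToyBlock()
x = torch.randn(2, 8, 12, requires_grad=True)
packed_mod = torch.randn(2, 1, 36, requires_grad=True)

# All inputs are plain tensors, so checkpointing can track their gradients correctly.
out = checkpoint(block, x, packed_mod, use_reentrant=False)
out.sum().backward()
```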
Steven Liu
6875490c3b [docs] add docs for qwenimagelayered (#13158)
* add example

* feedback
2026-02-18 11:02:15 -08:00
David El Malih
64734b2115 docs: improve docstring scheduling_flow_match_lcm.py (#13160)
Improve docstrings in scheduling_flow_match_lcm.py
2026-02-18 10:52:02 -08:00
sayakpaul
d6bc647932 change to updated repo and version. 2026-02-18 23:46:06 +05:30
10 changed files with 549 additions and 126 deletions

View File

@@ -199,11 +199,6 @@ jobs:
- name: Install dependencies
run: |
# Install pkgs which depend on setuptools<81 for pkg_resources first with no build isolation
uv pip install pip==25.2 setuptools==80.10.2
uv pip install --no-build-isolation k-diffusion==0.0.12
uv pip install --upgrade pip setuptools
# Install the rest as normal
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git

View File

@@ -126,11 +126,6 @@ jobs:
- name: Install dependencies
run: |
# Install pkgs which depend on setuptools<81 for pkg_resources first with no build isolation
uv pip install pip==25.2 setuptools==80.10.2
uv pip install --no-build-isolation k-diffusion==0.0.12
uv pip install --upgrade pip setuptools
# Install the rest as normal
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git

View File

@@ -29,7 +29,7 @@ Qwen-Image comes in the following variants:
| Qwen-Image-Edit Plus | [Qwen/Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) |
> [!TIP]
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
> See the [Caching](../../optimization/cache) guide to speed up inference by storing and reusing intermediate outputs.
## LoRA for faster inference
@@ -190,6 +190,12 @@ For detailed benchmark scripts and results, see [this gist](https://gist.github.
- all
- __call__
## QwenImageLayeredPipeline
[[autodoc]] QwenImageLayeredPipeline
- all
- __call__
## QwenImagePipelineOutput
[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput

View File

@@ -38,6 +38,7 @@ from ..utils import (
is_flash_attn_available,
is_flash_attn_version,
is_kernels_available,
is_kernels_version,
is_sageattention_available,
is_sageattention_version,
is_torch_npu_available,
@@ -265,28 +266,41 @@ class _HubKernelConfig:
repo_id: str
function_attr: str
revision: str | None = None
version: int | None = None
kernel_fn: Callable | None = None
wrapped_forward_attr: str | None = None
wrapped_backward_attr: str | None = None
wrapped_forward_fn: Callable | None = None
wrapped_backward_fn: Callable | None = None
# Registry for hub-based attention kernels
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
# TODO: temporary revision for now. Remove when merged upstream into `main`.
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", version=1
),
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
function_attr="flash_attn_varlen_func",
# revision="fake-ops-return-probs",
version=1,
),
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_func",
version=1,
revision=None,
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
),
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_varlen_func",
version=1,
),
AttentionBackendName.SAGE_HUB: _HubKernelConfig(
repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
repo_id="kernels-community/sage-attention",
function_attr="sageattn",
version=1,
),
}
@@ -456,6 +470,10 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
raise RuntimeError(
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
if not is_kernels_version(">=", "0.12"):
raise RuntimeError(
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
)
elif backend == AttentionBackendName.AITER:
if not _CAN_USE_AITER_ATTN:
@@ -605,22 +623,39 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# ===== Helpers for downloading kernels =====
def _resolve_kernel_attr(module, attr_path: str):
target = module
for attr in attr_path.split("."):
if not hasattr(target, attr):
raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
target = getattr(target, attr)
return target
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
if backend not in _HUB_KERNELS_REGISTRY:
return
config = _HUB_KERNELS_REGISTRY[backend]
if config.kernel_fn is not None:
needs_kernel = config.kernel_fn is None
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
return
try:
from kernels import get_kernel
kernel_module = get_kernel(config.repo_id, revision=config.revision)
kernel_func = getattr(kernel_module, config.function_attr)
if needs_kernel:
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
# Cache the downloaded kernel function in the config object
config.kernel_fn = kernel_func
if needs_wrapped_forward:
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
if needs_wrapped_backward:
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
@@ -1071,6 +1106,237 @@ def _flash_attention_backward_op(
return grad_query, grad_key, grad_value
def _flash_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float | None = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: "ParallelConfig" | None = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_forward_fn is None or wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
"for context parallel execution."
)
if scale is None:
scale = query.shape[-1] ** (-0.5)
window_size = (-1, -1)
softcap = 0.0
alibi_slopes = None
deterministic = False
grad_enabled = any(x.requires_grad for x in (query, key, value))
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
dropout_p = dropout_p if dropout_p > 0 else 1e-30
with torch.set_grad_enabled(grad_enabled):
out, lse, S_dmask, rng_state = wrapped_forward_fn(
query,
key,
value,
dropout_p,
scale,
is_causal,
window_size[0],
window_size[1],
softcap,
alibi_slopes,
return_lse,
)
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value, out, lse, rng_state)
ctx.dropout_p = dropout_p
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return (out, lse) if return_lse else out
def _flash_attention_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
)
query, key, value, out, lse, rng_state = ctx.saved_tensors
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
_ = wrapped_backward_fn(
grad_out,
query,
key,
value,
out,
lse,
grad_query,
grad_key,
grad_value,
ctx.dropout_p,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state,
)
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
return grad_query, grad_key, grad_value
def _flash_attention_3_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float | None = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: "ParallelConfig" | None = None,
*,
window_size: tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: bool | None = None,
deterministic: bool = False,
sm_margin: int = 0,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
if dropout_p != 0.0:
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=return_lse,
)
lse = None
if return_lse:
out, lse = out
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value)
ctx.scale = scale
ctx.is_causal = is_causal
ctx._hub_kernel = func
return (out, lse) if return_lse else out
def _flash_attention_3_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
window_size: tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: bool | None = None,
deterministic: bool = False,
sm_margin: int = 0,
):
query, key, value = ctx.saved_tensors
kernel_fn = ctx._hub_kernel
# NOTE: Unlike the FA2 hub kernel, the FA3 hub kernel does not expose separate wrapped forward/backward
# primitives (no `wrapped_forward_attr`/`wrapped_backward_attr` in its `_HubKernelConfig`). We
# therefore rerun the forward pass under `torch.enable_grad()` and differentiate through it with
# `torch.autograd.grad()`. This is a second forward pass during backward; it can be avoided once
# the FA3 hub exposes a dedicated fused backward kernel (analogous to `_wrapped_flash_attn_backward`
# in the FA2 hub), at which point this can be refactored to match `_flash_attention_hub_backward_op`.
with torch.enable_grad():
query_r = query.detach().requires_grad_(True)
key_r = key.detach().requires_grad_(True)
value_r = value.detach().requires_grad_(True)
out = kernel_fn(
q=query_r,
k=key_r,
v=value_r,
softmax_scale=ctx.scale,
causal=ctx.is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=False,
)
if isinstance(out, tuple):
out = out[0]
grad_query, grad_key, grad_value = torch.autograd.grad(
out,
(query_r, key_r, value_r),
grad_out,
retain_graph=False,
allow_unused=False,
)
return grad_query, grad_key, grad_value
def _sage_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -1109,6 +1375,46 @@ def _sage_attention_forward_op(
return (out, lse) if return_lse else out
def _sage_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float | None = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: "ParallelConfig" | None = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
if dropout_p > 0.0:
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
lse = None
if return_lse:
out, lse, *_ = out
lse = lse.permute(0, 2, 1).contiguous()
return (out, lse) if return_lse else out
def _sage_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
@@ -1965,7 +2271,7 @@ def _flash_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_attention_hub(
query: torch.Tensor,
@@ -1983,17 +2289,35 @@ def _flash_attention_hub(
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
is_causal,
scale,
False,
return_lse,
forward_op=_flash_attention_hub_forward_op,
backward_op=_flash_attention_hub_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out
@@ -2140,7 +2464,7 @@ def _flash_attention_3(
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_attention_3_hub(
query: torch.Tensor,
@@ -2155,33 +2479,68 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if _parallel_config:
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
return (out[0], out[1]) if return_attn_probs else out
forward_op = functools.partial(
_flash_attention_3_hub_forward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
# When `return_attn_probs` is True, the above returns a tuple of
# actual outputs and lse.
return (out[0], out[1]) if return_attn_probs else out
backward_op = functools.partial(
_flash_attention_3_hub_backward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
)
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_attn_probs,
forward_op=forward_op,
backward_op=backward_op,
_parallel_config=_parallel_config,
)
if return_attn_probs:
out, lse = out
return out, lse
return out
@_AttentionBackendRegistry.register(
@@ -2813,7 +3172,7 @@ def _sage_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_HUB,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _sage_attention_hub(
query: torch.Tensor,
@@ -2841,6 +3200,23 @@ def _sage_attention_hub(
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_lse,
forward_op=_sage_attention_hub_forward_op,
backward_op=_sage_attention_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out
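A standalone sketch of the lazy Hub-kernel fetch used above, mirroring the `get_kernel` call in `_maybe_download_kernel_for_backend`. The tensor shapes and keyword arguments follow the flash-attn2 call in `_flash_attention_hub` and are assumptions about that kernel's interface; running it requires a CUDA GPU and `pip install kernels`.

```python
# Fetch a Hub-hosted attention kernel directly, outside the diffusers registry.
import torch
from kernels import get_kernel

# Downloads (and caches) a kernel build compatible with the local torch/CUDA environment.
kernel_module = get_kernel("kernels-community/flash-attn2")
flash_attn_func = getattr(kernel_module, "flash_attn_func")

# flash-attn expects (batch, seq_len, num_heads, head_dim) inputs in bf16/fp16.
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
out = flash_attn_func(q=q, k=k, v=v, causal=False)
print(out.shape)
```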

View File

@@ -424,7 +424,7 @@ class Flux2SingleTransformerBlock(nn.Module):
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None,
temb_mod_params: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
temb_mod: torch.Tensor,
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
joint_attention_kwargs: dict[str, Any] | None = None,
split_hidden_states: bool = False,
@@ -436,7 +436,7 @@ class Flux2SingleTransformerBlock(nn.Module):
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
mod_shift, mod_scale, mod_gate = temb_mod_params
mod_shift, mod_scale, mod_gate = Flux2Modulation.split(temb_mod, 1)[0]
norm_hidden_states = self.norm(hidden_states)
norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
@@ -498,16 +498,18 @@ class Flux2TransformerBlock(nn.Module):
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb_mod_params_img: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
temb_mod_params_txt: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
temb_mod_img: torch.Tensor,
temb_mod_txt: torch.Tensor,
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
joint_attention_kwargs: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
joint_attention_kwargs = joint_attention_kwargs or {}
# Modulation parameters shape: [1, 1, self.dim]
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
(c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = Flux2Modulation.split(temb_mod_img, 2)
(c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = Flux2Modulation.split(
temb_mod_txt, 2
)
# Img stream
norm_hidden_states = self.norm1(hidden_states)
@@ -627,15 +629,19 @@ class Flux2Modulation(nn.Module):
self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
self.act_fn = nn.SiLU()
def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
def forward(self, temb: torch.Tensor) -> torch.Tensor:
mod = self.act_fn(temb)
mod = self.linear(mod)
return mod
@staticmethod
# split inside the transformer blocks, to avoid passing tuples into checkpoints https://github.com/huggingface/diffusers/issues/12776
def split(mod: torch.Tensor, mod_param_sets: int) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
if mod.ndim == 2:
mod = mod.unsqueeze(1)
mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
mod_params = torch.chunk(mod, 3 * mod_param_sets, dim=-1)
# Return tuple of 3-tuples of modulation params shift/scale/gate
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(mod_param_sets))
class Flux2Transformer2DModel(
@@ -824,7 +830,7 @@ class Flux2Transformer2DModel(
double_stream_mod_img = self.double_stream_modulation_img(temb)
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
single_stream_mod = self.single_stream_modulation(temb)[0]
single_stream_mod = self.single_stream_modulation(temb)
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
hidden_states = self.x_embedder(hidden_states)
@@ -861,8 +867,8 @@ class Flux2Transformer2DModel(
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb_mod_params_img=double_stream_mod_img,
temb_mod_params_txt=double_stream_mod_txt,
temb_mod_img=double_stream_mod_img,
temb_mod_txt=double_stream_mod_txt,
image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
@@ -884,7 +890,7 @@ class Flux2Transformer2DModel(
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=None,
temb_mod_params=single_stream_mod,
temb_mod=single_stream_mod,
image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)

View File

@@ -18,7 +18,6 @@ import re
import urllib.parse as ul
from typing import Callable
import ftfy
import torch
from transformers import (
AutoTokenizer,
@@ -34,13 +33,13 @@ from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.prx.pipeline_output import PRXPipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
logging,
replace_example_docstring,
)
from diffusers.utils import is_ftfy_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
if is_ftfy_available():
import ftfy
DEFAULT_RESOLUTION = 512
ASPECT_RATIO_256_BIN = {

View File

@@ -14,6 +14,7 @@
import math
from dataclasses import dataclass
from typing import Literal
import numpy as np
import torch
@@ -41,7 +42,7 @@ class FlowMatchLCMSchedulerOutput(BaseOutput):
denoising loop.
"""
prev_sample: torch.FloatTensor
prev_sample: torch.Tensor
class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
@@ -79,11 +80,11 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
use_beta_sigmas (`bool`, defaults to False):
Whether to use beta sigmas for step sizes in the noise schedule during sampling.
time_shift_type (`str`, defaults to "exponential"):
The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
scale_factors ('list', defaults to None)
The type of dynamic resolution-dependent timestep shifting to apply.
scale_factors (`list[float]`, *optional*, defaults to `None`):
It defines how to scale the latents at which predictions are made.
upscale_mode ('str', defaults to 'bicubic')
Upscaling method, applied if scale-wise generation is considered
upscale_mode (`str`, *optional*, defaults to "bicubic"):
Upscaling method, applied if scale-wise generation is considered.
"""
_compatibles = []
@@ -101,16 +102,33 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
max_image_seq_len: int = 4096,
invert_sigmas: bool = False,
shift_terminal: float | None = None,
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
time_shift_type: str = "exponential",
use_karras_sigmas: bool | None = False,
use_exponential_sigmas: bool | None = False,
use_beta_sigmas: bool | None = False,
time_shift_type: Literal["exponential", "linear"] = "exponential",
scale_factors: list[float] | None = None,
upscale_mode: str = "bicubic",
upscale_mode: Literal[
"nearest",
"linear",
"bilinear",
"bicubic",
"trilinear",
"area",
"nearest-exact",
] = "bicubic",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
if (
sum(
[
self.config.use_beta_sigmas,
self.config.use_exponential_sigmas,
self.config.use_karras_sigmas,
]
)
> 1
):
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
@@ -162,7 +180,7 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
@@ -172,18 +190,18 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
"""
self._begin_index = begin_index
def set_shift(self, shift: float):
def set_shift(self, shift: float) -> None:
self._shift = shift
def set_scale_factors(self, scale_factors: list, upscale_mode):
def set_scale_factors(self, scale_factors: list[float], upscale_mode: str) -> None:
"""
Sets scale factors for a scale-wise generation regime.
Args:
scale_factors (`list`):
The scale factors for each step
scale_factors (`list[float]`):
The scale factors for each step.
upscale_mode (`str`):
Upscaling method
Upscaling method.
"""
self._scale_factors = scale_factors
self._upscale_mode = upscale_mode
@@ -238,16 +256,18 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
return sample
def _sigma_to_t(self, sigma):
def _sigma_to_t(self, sigma: float | torch.FloatTensor) -> float | torch.FloatTensor:
return sigma * self.config.num_train_timesteps
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
def time_shift(
self, mu: float, sigma: float, t: float | np.ndarray | torch.Tensor
) -> float | np.ndarray | torch.Tensor:
if self.config.time_shift_type == "exponential":
return self._time_shift_exponential(mu, sigma, t)
elif self.config.time_shift_type == "linear":
return self._time_shift_linear(mu, sigma, t)
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
def stretch_shift_to_terminal(self, t: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
r"""
Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
value.
@@ -256,12 +276,13 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
Args:
t (`torch.Tensor`):
A tensor of timesteps to be stretched and shifted.
t (`torch.Tensor` or `np.ndarray`):
A tensor or numpy array of timesteps to be stretched and shifted.
Returns:
`torch.Tensor`:
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
`torch.Tensor` or `np.ndarray`:
A tensor or numpy array of adjusted timesteps such that the final value equals
`self.config.shift_terminal`.
"""
one_minus_z = 1 - t
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
@@ -270,12 +291,12 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int = None,
device: str | torch.device = None,
num_inference_steps: int | None = None,
device: str | torch.device | None = None,
sigmas: list[float] | None = None,
mu: float = None,
mu: float | None = None,
timesteps: list[float] | None = None,
):
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -317,43 +338,45 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
is_timesteps_provided = timesteps is not None
if is_timesteps_provided:
timesteps = np.array(timesteps).astype(np.float32)
timesteps = np.array(timesteps).astype(np.float32) # type: ignore
if sigmas is None:
if timesteps is None:
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
timesteps = np.linspace( # type: ignore
self._sigma_to_t(self.sigma_max),
self._sigma_to_t(self.sigma_min),
num_inference_steps,
)
sigmas = timesteps / self.config.num_train_timesteps
sigmas = timesteps / self.config.num_train_timesteps # type: ignore
else:
sigmas = np.array(sigmas).astype(np.float32)
sigmas = np.array(sigmas).astype(np.float32) # type: ignore
num_inference_steps = len(sigmas)
# 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
# "exponential" or "linear" type is applied
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas)
sigmas = self.time_shift(mu, 1.0, sigmas) # type: ignore
else:
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) # type: ignore
# 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
if self.config.shift_terminal:
sigmas = self.stretch_shift_to_terminal(sigmas)
sigmas = self.stretch_shift_to_terminal(sigmas) # type: ignore
# 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) # type: ignore
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) # type: ignore
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) # type: ignore
# 5. Convert sigmas and timesteps to tensors and move to specified device
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) # type: ignore
if not is_timesteps_provided:
timesteps = sigmas * self.config.num_train_timesteps
timesteps = sigmas * self.config.num_train_timesteps # type: ignore
else:
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device) # type: ignore
# 6. Append the terminal sigma value.
# If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
@@ -370,7 +393,11 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self,
timestep: float | torch.Tensor,
schedule_timesteps: torch.Tensor | None = None,
) -> int:
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -382,9 +409,9 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
return int(indices[pos].item())
def _init_step_index(self, timestep):
def _init_step_index(self, timestep: float | torch.Tensor) -> None:
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -459,7 +486,12 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
size = [round(self._scale_factors[self._step_index] * size) for size in self._init_size]
x0_pred = torch.nn.functional.interpolate(x0_pred, size=size, mode=self._upscale_mode)
noise = randn_tensor(x0_pred.shape, generator=generator, device=x0_pred.device, dtype=x0_pred.dtype)
noise = randn_tensor(
x0_pred.shape,
generator=generator,
device=x0_pred.device,
dtype=x0_pred.dtype,
)
prev_sample = (1 - sigma_next) * x0_pred + sigma_next * noise
# upon completion increase step index by one
@@ -473,7 +505,7 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
return FlowMatchLCMSchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
Models](https://huggingface.co/papers/2206.00364).
@@ -594,11 +626,15 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
)
return sigmas
def _time_shift_exponential(self, mu, sigma, t):
def _time_shift_exponential(
self, mu: float, sigma: float, t: float | np.ndarray | torch.Tensor
) -> float | np.ndarray | torch.Tensor:
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def _time_shift_linear(self, mu, sigma, t):
def _time_shift_linear(
self, mu: float, sigma: float, t: float | np.ndarray | torch.Tensor
) -> float | np.ndarray | torch.Tensor:
return mu / (mu + (1 / t - 1) ** sigma)
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps
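A brief usage sketch for the scale-wise options documented above, based only on the signatures in this diff; the top-level import path and the concrete scale factors are assumptions, not recommended values.

```python
# Configure scale-wise generation on the LCM flow-match scheduler (illustrative values).
from diffusers import FlowMatchLCMScheduler

scheduler = FlowMatchLCMScheduler()
# One scale factor per step: latents are upscaled between steps with the chosen mode.
scheduler.set_scale_factors(scale_factors=[0.5, 0.75, 1.0], upscale_mode="bicubic")
scheduler.set_timesteps(num_inference_steps=3)
print(scheduler.timesteps)
```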

View File

@@ -86,6 +86,7 @@ from .import_utils import (
is_inflect_available,
is_invisible_watermark_available,
is_kernels_available,
is_kernels_version,
is_kornia_available,
is_librosa_available,
is_matplotlib_available,

View File

@@ -724,6 +724,22 @@ def is_transformers_version(operation: str, version: str):
return compare_versions(parse(_transformers_version), operation, version)
@cache
def is_kernels_version(operation: str, version: str):
"""
Compares the current Kernels version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _kernels_available:
return False
return compare_versions(parse(_kernels_version), operation, version)
@cache
def is_hf_hub_version(operation: str, version: str):
"""

View File

@@ -1,7 +1,6 @@
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer
from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig
@@ -11,17 +10,11 @@ from diffusers.models import AutoencoderDC, AutoencoderKL
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.prx.pipeline_prx import PRXPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_transformers_version
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@pytest.mark.xfail(
condition=is_transformers_version(">", "4.57.1"),
reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
strict=False,
)
class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = PRXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}