Compare commits

..

13 Commits

Author SHA1 Message Date
sayakpaul
67f4691cab resolve conflicts. 2026-02-19 18:22:49 +05:30
Sayak Paul
99daaa802d [core] Enable CP for kernels-based attention backends (#12812)
* up

* up

* up

* up

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2026-02-19 18:16:50 +05:30
sayakpaul
e10fe61303 fix version and force updated kernels. 2026-02-19 18:00:01 +05:30
Sayak Paul
348350cf24 Merge branch 'main' into update-kernel-hub-repos 2026-02-19 17:53:46 +05:30
dg845
fe78a7b7c6 Fix ftfy import for PRX Pipeline (#13154)
* Guard ftfy import with is_ftfy_available

* Remove xfail for PRX pipeline tests as they appear to work on transformers>4.57.1

* make style and make quality
2026-02-18 20:44:33 -08:00
dg845
53e1d0e458 [CI] Revert setuptools CI Fix as the Failing Pipelines are Deprecated (#13149)
* Pin setuptools version for dependencies which explicitly depend on pkg_resources

* Revert setuptools pin as k-diffusion pipelines are now deprecated

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-02-18 20:34:00 -08:00
Sayak Paul
af35e3806c Merge branch 'main' into update-kernel-hub-repos 2026-02-19 09:35:15 +05:30
dxqb
a577ec36df Flux2: Tensor tuples can cause issues for checkpointing (#12777)
* split tensors inside the transformer blocks to avoid checkpointing issues

* clean up, fix type hints

* fix merge error

* Apply style fixes

---------

Co-authored-by: s <you@example.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-02-18 17:03:22 -08:00
Steven Liu
6875490c3b [docs] add docs for qwenimagelayered (#13158)
* add example

* feedback
2026-02-18 11:02:15 -08:00
David El Malih
64734b2115 docs: improve docstring scheduling_flow_match_lcm.py (#13160)
Improve docstring scheduling flow match lcm
2026-02-18 10:52:02 -08:00
sayakpaul
d6bc647932 change to updated repo and version. 2026-02-18 23:46:06 +05:30
Dhruv Nair
f81e653197 [CI] Add ftfy as a test dependency (#13155)
* update

* update

* update

* update

* update

* update
2026-02-18 22:51:10 +05:30
zhangtao0408
bcbbded7c3 [Bug] Fix QwenImageEditPlus Series on NPU (#13017)
* [Bug Fix][Qwen-Image-Edit] Fix Qwen-Image-Edit series on NPU

* Enhance NPU attention handling by converting attention mask to boolean and refining mask checks.

* Refine attention mask handling in NPU attention function to improve validation and conversion logic.

* Clean Code

* Refine attention mask processing in NPU attention functions to enhance performance and validation.

* Remove item() ops on npu fa backend.

* Reuse NPU attention mask by `_maybe_modify_attn_mask_npu`

* Apply style fixes

* Update src/diffusers/models/attention_dispatch.py

---------

Co-authored-by: zhangtao <zhangtao529@huawei.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2026-02-17 09:10:40 +05:30
25 changed files with 1479 additions and 1045 deletions

View File

@@ -117,7 +117,7 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip install -e ".[quality,test]"
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps

View File

@@ -114,7 +114,7 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip install -e ".[quality,test]"
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
@@ -191,7 +191,7 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip install -e ".[quality,test]"
- name: Environment
run: |
@@ -242,7 +242,7 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip install -e ".[quality,test]"
# TODO (sayakpaul, DN6): revisit `--no-deps`
uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
uv pip install -U tokenizers

View File

@@ -199,11 +199,6 @@ jobs:
- name: Install dependencies
run: |
# Install pkgs which depend on setuptools<81 for pkg_resources first with no build isolation
uv pip install pip==25.2 setuptools==80.10.2
uv pip install --no-build-isolation k-diffusion==0.0.12
uv pip install --upgrade pip setuptools
# Install the rest as normal
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git

View File

@@ -126,11 +126,6 @@ jobs:
- name: Install dependencies
run: |
# Install pkgs which depend on setuptools<81 for pkg_resources first with no build isolation
uv pip install pip==25.2 setuptools==80.10.2
uv pip install --no-build-isolation k-diffusion==0.0.12
uv pip install --upgrade pip setuptools
# Install the rest as normal
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git

View File

@@ -29,7 +29,7 @@ Qwen-Image comes in the following variants:
| Qwen-Image-Edit Plus | [Qwen/Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) |
> [!TIP]
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
> See the [Caching](../../optimization/cache) guide to speed up inference by storing and reusing intermediate outputs.
## LoRA for faster inference
@@ -190,6 +190,12 @@ For detailed benchmark scripts and results, see [this gist](https://gist.github.
- all
- __call__
## QwenImageLayeredPipeline
[[autodoc]] QwenImageLayeredPipeline
- all
- __call__
## QwenImagePipelineOutput
[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput

View File

@@ -101,6 +101,7 @@ _deps = [
"datasets",
"filelock",
"flax>=0.4.1",
"ftfy",
"hf-doc-builder>=0.3.0",
"httpx<1.0.0",
"huggingface-hub>=0.34.0,<2.0",
@@ -221,12 +222,14 @@ extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft", "timm")
extras["test"] = deps_list(
"compel",
"ftfy",
"GitPython",
"datasets",
"Jinja2",
"invisible-watermark",
"librosa",
"parameterized",
"protobuf",
"pytest",
"pytest-timeout",
"pytest-xdist",
@@ -235,6 +238,7 @@ extras["test"] = deps_list(
"sentencepiece",
"scipy",
"tiktoken",
"torchsde",
"torchvision",
"transformers",
"phonemizer",

View File

@@ -8,6 +8,7 @@ deps = {
"datasets": "datasets",
"filelock": "filelock",
"flax": "flax>=0.4.1",
"ftfy": "ftfy",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"httpx": "httpx<1.0.0",
"huggingface-hub": "huggingface-hub>=0.34.0,<2.0",

View File

@@ -38,6 +38,7 @@ from ..utils import (
is_flash_attn_available,
is_flash_attn_version,
is_kernels_available,
is_kernels_version,
is_sageattention_available,
is_sageattention_version,
is_torch_npu_available,
@@ -265,28 +266,41 @@ class _HubKernelConfig:
repo_id: str
function_attr: str
revision: str | None = None
version: int | None = None
kernel_fn: Callable | None = None
wrapped_forward_attr: str | None = None
wrapped_backward_attr: str | None = None
wrapped_forward_fn: Callable | None = None
wrapped_backward_fn: Callable | None = None
# Registry for hub-based attention kernels
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
# TODO: temporary revision for now. Remove when merged upstream into `main`.
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", version=1
),
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
function_attr="flash_attn_varlen_func",
# revision="fake-ops-return-probs",
version=1,
),
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_func",
version=1,
revision=None,
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
),
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_varlen_func",
version=1,
),
AttentionBackendName.SAGE_HUB: _HubKernelConfig(
repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
repo_id="kernels-community/sage-attention",
function_attr="sageattn",
version=1,
),
}
@@ -456,6 +470,10 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
raise RuntimeError(
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
if not is_kernels_version(">=", "0.12"):
raise RuntimeError(
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
)
elif backend == AttentionBackendName.AITER:
if not _CAN_USE_AITER_ATTN:
@@ -605,22 +623,39 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# ===== Helpers for downloading kernels =====
def _resolve_kernel_attr(module, attr_path: str):
    """Resolve a dotted attribute path starting from a kernel module.

    Args:
        module: Object the lookup starts from.
        attr_path: Dot-separated attribute names, followed one segment at a time
            (e.g. `"flash_attn_interface._wrapped_flash_attn_forward"`).

    Returns:
        The object reached after following every segment of `attr_path`.

    Raises:
        AttributeError: If any segment of `attr_path` cannot be found.
    """
    target = module
    for attr in attr_path.split("."):
        if not hasattr(target, attr):
            # The root object is not guaranteed to carry `__name__`; fall back to its repr so that
            # building the message can never raise an unrelated AttributeError that hides `attr_path`.
            module_name = getattr(module, "__name__", repr(module))
            raise AttributeError(f"Kernel module '{module_name}' does not define attribute path '{attr_path}'.")
        target = getattr(target, attr)
    return target
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
if backend not in _HUB_KERNELS_REGISTRY:
return
config = _HUB_KERNELS_REGISTRY[backend]
if config.kernel_fn is not None:
needs_kernel = config.kernel_fn is None
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
return
try:
from kernels import get_kernel
kernel_module = get_kernel(config.repo_id, revision=config.revision)
kernel_func = getattr(kernel_module, config.function_attr)
if needs_kernel:
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
# Cache the downloaded kernel function in the config object
config.kernel_fn = kernel_func
if needs_wrapped_forward:
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
if needs_wrapped_backward:
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
@@ -1071,6 +1106,237 @@ def _flash_attention_backward_op(
return grad_query, grad_key, grad_value
def _flash_attention_hub_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
):
    """Forward op built on the wrapped forward primitive of the `FLASH_HUB` kernel.

    The wrapped forward/backward callables are read from the `FLASH_HUB` entry of
    `_HUB_KERNELS_REGISTRY`. When `_save_ctx` is True, the tensors and scalar options that
    `_flash_attention_hub_backward_op` reads back are stored on `ctx`.

    Args:
        ctx: Autograd context used to stash state for the backward op.
        query, key, value: Attention inputs, forwarded positionally to the kernel.
        attn_mask: Must be `None`.
        dropout_p: Dropout probability. Replaced by `1e-30` when it is not positive and either an input
            requires grad or the context-parallel world size is greater than 1.
        is_causal: Forwarded to the kernel.
        scale: Softmax scale; defaults to `query.shape[-1] ** -0.5` when `None`.
        enable_gqa: Must be falsy.
        return_lse: Forwarded as the kernel's last positional argument and selects the return form.
        _save_ctx: Whether to store backward state on `ctx`.
        _parallel_config: Only inspected for `context_parallel_config._world_size`.

    Returns:
        `(out, lse)` when `return_lse` is True, otherwise `out`. `lse` is the kernel's second output with
        its last two dimensions swapped and made contiguous.

    Raises:
        ValueError: If `attn_mask` is given or `enable_gqa` is set.
        RuntimeError: If the registry entry lacks the wrapped forward or backward callable.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")

    hub_config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
    forward_kernel = hub_config.wrapped_forward_fn
    backward_kernel = hub_config.wrapped_backward_fn
    if forward_kernel is None or backward_kernel is None:
        raise RuntimeError(
            "Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
            "for context parallel execution."
        )

    softmax_scale = query.shape[-1] ** (-0.5) if scale is None else scale

    # Kernel options this op does not expose to callers; they are pinned here and mirrored onto `ctx`.
    window_size = (-1, -1)
    softcap = 0.0
    alibi_slopes = None
    deterministic = False

    needs_grad = any(tensor.requires_grad for tensor in (query, key, value))
    if needs_grad or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
        # NOTE(review): a tiny non-zero dropout is substituted here, presumably so the kernel still
        # produces the outputs needed downstream (LSE / RNG state) — confirm against the kernel.
        if not dropout_p > 0:
            dropout_p = 1e-30

    with torch.set_grad_enabled(needs_grad):
        # The third output is not used by this op.
        out, lse, _, rng_state = forward_kernel(
            query,
            key,
            value,
            dropout_p,
            softmax_scale,
            is_causal,
            window_size[0],
            window_size[1],
            softcap,
            alibi_slopes,
            return_lse,
        )
        lse = lse.permute(0, 2, 1).contiguous()

    if _save_ctx:
        ctx.save_for_backward(query, key, value, out, lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.scale = softmax_scale
        ctx.is_causal = is_causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic

    if return_lse:
        return out, lse
    return out
def _flash_attention_hub_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    **kwargs,
):
    """Backward op built on the wrapped backward primitive of the `FLASH_HUB` kernel.

    Reads the tensors and options that `_flash_attention_hub_forward_op` stored on `ctx` and passes
    them, together with freshly allocated gradient buffers, to the wrapped backward callable.

    Args:
        ctx: Autograd context populated by the forward op.
        grad_out: Gradient with respect to the attention output.
        *args, **kwargs: Accepted and ignored.

    Returns:
        `(grad_query, grad_key, grad_value)`, each sliced on the last dimension to `grad_out.shape[-1]`.

    Raises:
        RuntimeError: If the registry entry lacks the wrapped backward callable.
    """
    backward_kernel = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].wrapped_backward_fn
    if backward_kernel is None:
        raise RuntimeError(
            "Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
        )

    query, key, value, out, lse, rng_state = ctx.saved_tensors

    # Output buffers handed to the kernel; it is expected to fill them in place, since its return
    # value is discarded and these buffers are what gets returned.
    grad_query = torch.empty_like(query)
    grad_key = torch.empty_like(key)
    grad_value = torch.empty_like(value)

    backward_kernel(
        grad_out,
        query,
        key,
        value,
        out,
        lse,
        grad_query,
        grad_key,
        grad_value,
        ctx.dropout_p,
        ctx.scale,
        ctx.is_causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        ctx.alibi_slopes,
        ctx.deterministic,
        rng_state,
    )

    # Trim each gradient's last dimension to that of `grad_out`.
    last_dim = grad_out.shape[-1]
    return grad_query[..., :last_dim], grad_key[..., :last_dim], grad_value[..., :last_dim]
def _flash_attention_3_hub_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
    *,
    window_size: tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: bool | None = None,
    deterministic: bool = False,
    sm_margin: int = 0,
):
    """Forward op that calls the `_FLASH_3_HUB` kernel function from `_HUB_KERNELS_REGISTRY`.

    Args:
        ctx: Autograd context; receives the inputs, `scale`, `is_causal` and the kernel callable when
            `_save_ctx` is True.
        query, key, value: Attention inputs, passed to the kernel as `q`, `k`, `v`.
        attn_mask: Must be `None`.
        dropout_p: Must be `0.0`.
        is_causal: Passed to the kernel as `causal`.
        scale: Passed to the kernel as `softmax_scale` (unchanged, so it may be `None`).
        enable_gqa: Must be falsy.
        return_lse: Passed to the kernel as `return_attn_probs` and selects the return form.
        _save_ctx: Whether to store backward state on `ctx`.
        _parallel_config: Accepted for signature compatibility; not read in this function.
        window_size, softcap, num_splits, pack_gqa, deterministic, sm_margin: Keyword-only options
            forwarded to the kernel as-is.

    Returns:
        `(out, lse)` when `return_lse` is True, otherwise `out`. `lse` is the kernel's second output with
        its last two dimensions swapped and made contiguous.

    Raises:
        ValueError: If `attn_mask` is given, `dropout_p` is non-zero, or `enable_gqa` is set.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
    if dropout_p != 0.0:
        raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")

    kernel_fn = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
    result = kernel_fn(
        q=query,
        k=key,
        v=value,
        softmax_scale=scale,
        causal=is_causal,
        qv=None,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        window_size=window_size,
        softcap=softcap,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        deterministic=deterministic,
        sm_margin=sm_margin,
        return_attn_probs=return_lse,
    )

    if return_lse:
        # With `return_attn_probs=True` the kernel result is unpacked as an `(out, lse)` pair.
        out, lse = result
        lse = lse.permute(0, 2, 1).contiguous()
    else:
        out, lse = result, None

    if _save_ctx:
        ctx.save_for_backward(query, key, value)
        ctx.scale = scale
        ctx.is_causal = is_causal
        # Keep the exact callable on `ctx` so the backward op can invoke the same kernel again.
        ctx._hub_kernel = kernel_fn

    if return_lse:
        return out, lse
    return out
def _flash_attention_3_hub_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    window_size: tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: bool | None = None,
    deterministic: bool = False,
    sm_margin: int = 0,
):
    """Backward op for the FA3 hub kernel, implemented by recomputing the forward pass.

    Args:
        ctx: Autograd context providing `saved_tensors` (query, key, value), `_hub_kernel`, `scale` and
            `is_causal`.
        grad_out: Gradient with respect to the attention output.
        *args: Accepted and ignored.
        window_size, softcap, num_splits, pack_gqa, deterministic, sm_margin: Keyword-only options
            forwarded to the kernel as-is.

    Returns:
        `(grad_query, grad_key, grad_value)` as computed by `torch.autograd.grad`.
    """
    saved_query, saved_key, saved_value = ctx.saved_tensors
    kernel_fn = ctx._hub_kernel

    # NOTE: Unlike the FA2 hub kernel, the FA3 hub kernel does not expose separate wrapped forward/backward
    # primitives (its `_HubKernelConfig` has no `wrapped_forward_attr`/`wrapped_backward_attr`). The forward
    # pass is therefore re-run under `torch.enable_grad()` and differentiated with `torch.autograd.grad()`.
    # That costs a second forward pass during backward; once the FA3 hub ships a dedicated fused backward
    # kernel (analogous to `_wrapped_flash_attn_backward` in the FA2 hub) this can be refactored to match
    # `_flash_attention_hub_backward_op`.
    with torch.enable_grad():
        # Detached copies that require grad, so the recomputation builds its own graph.
        query_r, key_r, value_r = (
            tensor.detach().requires_grad_(True) for tensor in (saved_query, saved_key, saved_value)
        )
        recomputed = kernel_fn(
            q=query_r,
            k=key_r,
            v=value_r,
            softmax_scale=ctx.scale,
            causal=ctx.is_causal,
            qv=None,
            q_descale=None,
            k_descale=None,
            v_descale=None,
            window_size=window_size,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            deterministic=deterministic,
            sm_margin=sm_margin,
            return_attn_probs=False,
        )

    # The kernel may return a tuple; only its first element is differentiated.
    if isinstance(recomputed, tuple):
        recomputed = recomputed[0]

    grad_query, grad_key, grad_value = torch.autograd.grad(
        recomputed,
        (query_r, key_r, value_r),
        grad_out,
        retain_graph=False,
        allow_unused=False,
    )
    return grad_query, grad_key, grad_value
def _sage_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -1109,6 +1375,46 @@ def _sage_attention_forward_op(
return (out, lse) if return_lse else out
def _sage_attention_hub_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
):
    """Forward op that calls the `SAGE_HUB` kernel function from `_HUB_KERNELS_REGISTRY`.

    Args:
        ctx: Autograd context; not written to by this function.
        query, key, value: Attention inputs, passed to the kernel as `q`, `k`, `v` with
            `tensor_layout="NHD"`.
        attn_mask: Must be `None`.
        dropout_p: Must not be greater than `0.0`.
        is_causal: Forwarded to the kernel.
        scale: Passed to the kernel as `sm_scale`.
        enable_gqa: Must be falsy.
        return_lse: Forwarded to the kernel and selects the return form.
        _save_ctx, _parallel_config: Accepted for signature compatibility; not read in this function.

    Returns:
        `(out, lse)` when `return_lse` is True, otherwise `out`. `lse` is the kernel's second output with
        its last two dimensions swapped and made contiguous.

    Raises:
        ValueError: If `attn_mask` is given, `dropout_p` is positive, or `enable_gqa` is set.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not yet supported for Sage attention.")
    if dropout_p > 0.0:
        raise ValueError("`dropout_p` is not yet supported for Sage attention.")
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for Sage attention.")

    kernel_fn = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
    result = kernel_fn(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        is_causal=is_causal,
        sm_scale=scale,
        return_lse=return_lse,
    )

    if not return_lse:
        return result

    # Only the first two entries of the kernel result are used; anything after them is dropped.
    out, lse, *_ = result
    lse = lse.permute(0, 2, 1).contiguous()
    return out, lse
def _sage_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
@@ -1117,6 +1423,26 @@ def _sage_attention_backward_op(
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mask: torch.Tensor | None = None):
    """Prepare an attention mask for the NPU fused attention call.

    Args:
        query: Query tensor; dims 0 and 1 are read as batch size and query sequence length.
        key: Key tensor; dim 1 is read as key/value sequence length.
        attn_mask: Optional mask.

    Returns:
        - `None` when `attn_mask` is `None` or has no zero entries.
        - For a 2D mask of shape `[query.shape[0], key.shape[1]]`: a contiguous boolean tensor of shape
          `[batch_size, 1, seq_len_q, seq_len_k]` that is True where the input mask was zero.
        - Otherwise `attn_mask` unchanged.
    """
    if attn_mask is None:
        return None

    # A mask with no zero entries masks nothing; passing `None` instead can speed up the computation.
    if torch.all(attn_mask != 0):
        return None

    # Only a 2D mask matching [batch_size, seq_len_k] is converted; anything else passes through untouched.
    if attn_mask.ndim != 2:
        return attn_mask
    if attn_mask.shape[0] != query.shape[0] or attn_mask.shape[1] != key.shape[1]:
        return attn_mask

    # Reshape: [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
    # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
    batch_size, seq_len_q, seq_len_k = attn_mask.shape[0], query.shape[1], key.shape[1]
    inverted = ~attn_mask.to(torch.bool)
    per_query = inverted.unsqueeze(1).expand(batch_size, seq_len_q, seq_len_k)
    return per_query.unsqueeze(1).contiguous()
def _npu_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -1134,11 +1460,14 @@ def _npu_attention_forward_op(
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)
out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
atten_mask=attn_mask,
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -1942,7 +2271,7 @@ def _flash_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_attention_hub(
query: torch.Tensor,
@@ -1960,17 +2289,35 @@ def _flash_attention_hub(
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
is_causal,
scale,
False,
return_lse,
forward_op=_flash_attention_hub_forward_op,
backward_op=_flash_attention_hub_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out
@@ -2117,7 +2464,7 @@ def _flash_attention_3(
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _flash_attention_3_hub(
query: torch.Tensor,
@@ -2132,33 +2479,68 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if _parallel_config:
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
return (out[0], out[1]) if return_attn_probs else out
forward_op = functools.partial(
_flash_attention_3_hub_forward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
# When `return_attn_probs` is True, the above returns a tuple of
# actual outputs and lse.
return (out[0], out[1]) if return_attn_probs else out
backward_op = functools.partial(
_flash_attention_3_hub_backward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
)
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_attn_probs,
forward_op=forward_op,
backward_op=backward_op,
_parallel_config=_parallel_config,
)
if return_attn_probs:
out, lse = out
return out, lse
return out
@_AttentionBackendRegistry.register(
@@ -2668,16 +3050,17 @@ def _native_npu_attention(
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for NPU attention")
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
if _parallel_config is None:
attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)
out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
atten_mask=attn_mask,
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -2692,7 +3075,7 @@ def _native_npu_attention(
query,
key,
value,
None,
attn_mask,
dropout_p,
None,
scale,
@@ -2789,7 +3172,7 @@ def _sage_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_HUB,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
supports_context_parallel=True,
)
def _sage_attention_hub(
query: torch.Tensor,
@@ -2817,6 +3200,23 @@ def _sage_attention_hub(
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_lse,
forward_op=_sage_attention_hub_forward_op,
backward_op=_sage_attention_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out

View File

@@ -424,7 +424,7 @@ class Flux2SingleTransformerBlock(nn.Module):
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None,
temb_mod_params: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
temb_mod: torch.Tensor,
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
joint_attention_kwargs: dict[str, Any] | None = None,
split_hidden_states: bool = False,
@@ -436,7 +436,7 @@ class Flux2SingleTransformerBlock(nn.Module):
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
mod_shift, mod_scale, mod_gate = temb_mod_params
mod_shift, mod_scale, mod_gate = Flux2Modulation.split(temb_mod, 1)[0]
norm_hidden_states = self.norm(hidden_states)
norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
@@ -498,16 +498,18 @@ class Flux2TransformerBlock(nn.Module):
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb_mod_params_img: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
temb_mod_params_txt: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
temb_mod_img: torch.Tensor,
temb_mod_txt: torch.Tensor,
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
joint_attention_kwargs: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
joint_attention_kwargs = joint_attention_kwargs or {}
# Modulation parameters shape: [1, 1, self.dim]
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
(c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = Flux2Modulation.split(temb_mod_img, 2)
(c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = Flux2Modulation.split(
temb_mod_txt, 2
)
# Img stream
norm_hidden_states = self.norm1(hidden_states)
@@ -627,15 +629,19 @@ class Flux2Modulation(nn.Module):
self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
self.act_fn = nn.SiLU()
def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
def forward(self, temb: torch.Tensor) -> torch.Tensor:
mod = self.act_fn(temb)
mod = self.linear(mod)
return mod
@staticmethod
# split inside the transformer blocks, to avoid passing tuples into checkpoints https://github.com/huggingface/diffusers/issues/12776
def split(mod: torch.Tensor, mod_param_sets: int) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
if mod.ndim == 2:
mod = mod.unsqueeze(1)
mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
mod_params = torch.chunk(mod, 3 * mod_param_sets, dim=-1)
# Return tuple of 3-tuples of modulation params shift/scale/gate
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(mod_param_sets))
class Flux2Transformer2DModel(
@@ -824,7 +830,7 @@ class Flux2Transformer2DModel(
double_stream_mod_img = self.double_stream_modulation_img(temb)
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
single_stream_mod = self.single_stream_modulation(temb)[0]
single_stream_mod = self.single_stream_modulation(temb)
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
hidden_states = self.x_embedder(hidden_states)
@@ -861,8 +867,8 @@ class Flux2Transformer2DModel(
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb_mod_params_img=double_stream_mod_img,
temb_mod_params_txt=double_stream_mod_txt,
temb_mod_img=double_stream_mod_img,
temb_mod_txt=double_stream_mod_txt,
image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
@@ -884,7 +890,7 @@ class Flux2Transformer2DModel(
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=None,
temb_mod_params=single_stream_mod,
temb_mod=single_stream_mod,
image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)

View File

@@ -164,7 +164,11 @@ def compute_text_seq_len_from_mask(
position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
has_active = encoder_hidden_states_mask.any(dim=1)
per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
per_sample_len = torch.where(
has_active,
active_positions.max(dim=1).values + 1,
torch.as_tensor(text_seq_len, device=encoder_hidden_states.device),
)
return text_seq_len, per_sample_len, encoder_hidden_states_mask

View File

@@ -18,7 +18,6 @@ import re
import urllib.parse as ul
from typing import Callable
import ftfy
import torch
from transformers import (
AutoTokenizer,
@@ -34,13 +33,13 @@ from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.prx.pipeline_output import PRXPipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
logging,
replace_example_docstring,
)
from diffusers.utils import is_ftfy_available, logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
if is_ftfy_available():
import ftfy
DEFAULT_RESOLUTION = 512
ASPECT_RATIO_256_BIN = {

View File

@@ -14,6 +14,7 @@
import math
from dataclasses import dataclass
from typing import Literal
import numpy as np
import torch
@@ -41,7 +42,7 @@ class FlowMatchLCMSchedulerOutput(BaseOutput):
denoising loop.
"""
prev_sample: torch.FloatTensor
prev_sample: torch.Tensor
class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
@@ -79,11 +80,11 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
use_beta_sigmas (`bool`, defaults to False):
Whether to use beta sigmas for step sizes in the noise schedule during sampling.
time_shift_type (`str`, defaults to "exponential"):
The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
scale_factors ('list', defaults to None)
The type of dynamic resolution-dependent timestep shifting to apply.
scale_factors (`list[float]`, *optional*, defaults to `None`):
It defines how to scale the latents at which predictions are made.
upscale_mode ('str', defaults to 'bicubic')
Upscaling method, applied if scale-wise generation is considered
upscale_mode (`str`, *optional*, defaults to "bicubic"):
Upscaling method, applied if scale-wise generation is considered.
"""
_compatibles = []
@@ -101,16 +102,33 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
max_image_seq_len: int = 4096,
invert_sigmas: bool = False,
shift_terminal: float | None = None,
use_karras_sigmas: bool = False,
use_exponential_sigmas: bool = False,
use_beta_sigmas: bool = False,
time_shift_type: str = "exponential",
use_karras_sigmas: bool | None = False,
use_exponential_sigmas: bool | None = False,
use_beta_sigmas: bool | None = False,
time_shift_type: Literal["exponential", "linear"] = "exponential",
scale_factors: list[float] | None = None,
upscale_mode: str = "bicubic",
upscale_mode: Literal[
"nearest",
"linear",
"bilinear",
"bicubic",
"trilinear",
"area",
"nearest-exact",
] = "bicubic",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
if (
sum(
[
self.config.use_beta_sigmas,
self.config.use_exponential_sigmas,
self.config.use_karras_sigmas,
]
)
> 1
):
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
@@ -162,7 +180,7 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
@@ -172,18 +190,18 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
"""
self._begin_index = begin_index
def set_shift(self, shift: float):
def set_shift(self, shift: float) -> None:
self._shift = shift
def set_scale_factors(self, scale_factors: list, upscale_mode):
def set_scale_factors(self, scale_factors: list[float], upscale_mode: str) -> None:
"""
Sets scale factors for a scale-wise generation regime.
Args:
scale_factors (`list`):
The scale factors for each step
scale_factors (`list[float]`):
The scale factors for each step.
upscale_mode (`str`):
Upscaling method
Upscaling method.
"""
self._scale_factors = scale_factors
self._upscale_mode = upscale_mode
@@ -238,16 +256,18 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
return sample
def _sigma_to_t(self, sigma):
def _sigma_to_t(self, sigma: float | torch.FloatTensor) -> float | torch.FloatTensor:
return sigma * self.config.num_train_timesteps
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
def time_shift(
self, mu: float, sigma: float, t: float | np.ndarray | torch.Tensor
) -> float | np.ndarray | torch.Tensor:
if self.config.time_shift_type == "exponential":
return self._time_shift_exponential(mu, sigma, t)
elif self.config.time_shift_type == "linear":
return self._time_shift_linear(mu, sigma, t)
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
def stretch_shift_to_terminal(self, t: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
r"""
Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
value.
@@ -256,12 +276,13 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
Args:
t (`torch.Tensor`):
A tensor of timesteps to be stretched and shifted.
t (`torch.Tensor` or `np.ndarray`):
A tensor or numpy array of timesteps to be stretched and shifted.
Returns:
`torch.Tensor`:
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
`torch.Tensor` or `np.ndarray`:
A tensor or numpy array of adjusted timesteps such that the final value equals
`self.config.shift_terminal`.
"""
one_minus_z = 1 - t
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
@@ -270,12 +291,12 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int = None,
device: str | torch.device = None,
num_inference_steps: int | None = None,
device: str | torch.device | None = None,
sigmas: list[float] | None = None,
mu: float = None,
mu: float | None = None,
timesteps: list[float] | None = None,
):
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -317,43 +338,45 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
is_timesteps_provided = timesteps is not None
if is_timesteps_provided:
timesteps = np.array(timesteps).astype(np.float32)
timesteps = np.array(timesteps).astype(np.float32) # type: ignore
if sigmas is None:
if timesteps is None:
timesteps = np.linspace(
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
timesteps = np.linspace( # type: ignore
self._sigma_to_t(self.sigma_max),
self._sigma_to_t(self.sigma_min),
num_inference_steps,
)
sigmas = timesteps / self.config.num_train_timesteps
sigmas = timesteps / self.config.num_train_timesteps # type: ignore
else:
sigmas = np.array(sigmas).astype(np.float32)
sigmas = np.array(sigmas).astype(np.float32) # type: ignore
num_inference_steps = len(sigmas)
# 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
# "exponential" or "linear" type is applied
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas)
sigmas = self.time_shift(mu, 1.0, sigmas) # type: ignore
else:
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) # type: ignore
# 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
if self.config.shift_terminal:
sigmas = self.stretch_shift_to_terminal(sigmas)
sigmas = self.stretch_shift_to_terminal(sigmas) # type: ignore
# 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) # type: ignore
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps) # type: ignore
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps) # type: ignore
# 5. Convert sigmas and timesteps to tensors and move to specified device
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device) # type: ignore
if not is_timesteps_provided:
timesteps = sigmas * self.config.num_train_timesteps
timesteps = sigmas * self.config.num_train_timesteps # type: ignore
else:
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device)
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=device) # type: ignore
# 6. Append the terminal sigma value.
# If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
@@ -370,7 +393,11 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
self._step_index = None
self._begin_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
def index_for_timestep(
self,
timestep: float | torch.Tensor,
schedule_timesteps: torch.Tensor | None = None,
) -> int:
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
@@ -382,9 +409,9 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
return int(indices[pos].item())
def _init_step_index(self, timestep):
def _init_step_index(self, timestep: float | torch.Tensor) -> None:
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -459,7 +486,12 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
size = [round(self._scale_factors[self._step_index] * size) for size in self._init_size]
x0_pred = torch.nn.functional.interpolate(x0_pred, size=size, mode=self._upscale_mode)
noise = randn_tensor(x0_pred.shape, generator=generator, device=x0_pred.device, dtype=x0_pred.dtype)
noise = randn_tensor(
x0_pred.shape,
generator=generator,
device=x0_pred.device,
dtype=x0_pred.dtype,
)
prev_sample = (1 - sigma_next) * x0_pred + sigma_next * noise
# upon completion increase step index by one
@@ -473,7 +505,7 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
return FlowMatchLCMSchedulerOutput(prev_sample=prev_sample)
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
"""
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
Models](https://huggingface.co/papers/2206.00364).
@@ -594,11 +626,15 @@ class FlowMatchLCMScheduler(SchedulerMixin, ConfigMixin):
)
return sigmas
def _time_shift_exponential(self, mu, sigma, t):
def _time_shift_exponential(
self, mu: float, sigma: float, t: float | np.ndarray | torch.Tensor
) -> float | np.ndarray | torch.Tensor:
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def _time_shift_linear(self, mu, sigma, t):
def _time_shift_linear(
self, mu: float, sigma: float, t: float | np.ndarray | torch.Tensor
) -> float | np.ndarray | torch.Tensor:
return mu / (mu + (1 / t - 1) ** sigma)
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -86,6 +86,7 @@ from .import_utils import (
is_inflect_available,
is_invisible_watermark_available,
is_kernels_available,
is_kernels_version,
is_kornia_available,
is_librosa_available,
is_matplotlib_available,

View File

@@ -724,6 +724,22 @@ def is_transformers_version(operation: str, version: str):
return compare_versions(parse(_transformers_version), operation, version)
@cache
def is_kernels_version(operation: str, version: str):
    """
    Check the installed Kernels version against a reference version.

    Args:
        operation (`str`):
            String form of a comparison operator, such as `">"` or `"<="`.
        version (`str`):
            The reference version string to compare against.

    Returns:
        `False` when Kernels is not available; otherwise whatever `compare_versions` returns for the parsed
        installed version, the operator and the reference version.
    """
    if _kernels_available:
        return compare_versions(parse(_kernels_version), operation, version)
    # Without the package there is no version to compare, so every check fails.
    return False
@cache
def is_hf_hub_version(operation: str, version: str):
"""

View File

@@ -465,8 +465,7 @@ class UNetTesterMixin:
def test_forward_with_norm_groups(self):
if not self._accepts_norm_num_groups(self.model_class):
pytest.skip(f"Test not supported for {self.model_class.__name__}")
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)
@@ -481,9 +480,9 @@ class UNetTesterMixin:
if isinstance(output, dict):
output = output.to_tuple()[0]
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
class ModelTesterMixin:

View File

@@ -287,9 +287,8 @@ class ModelTesterMixin:
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
)
inputs_dict = self.get_dummy_inputs()
image = model(**inputs_dict, return_dict=False)[0]
new_image = new_model(**inputs_dict, return_dict=False)[0]
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@@ -309,9 +308,8 @@ class ModelTesterMixin:
new_model.to(torch_device)
inputs_dict = self.get_dummy_inputs()
image = model(**inputs_dict, return_dict=False)[0]
new_image = new_model(**inputs_dict, return_dict=False)[0]
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@@ -339,9 +337,8 @@ class ModelTesterMixin:
model.to(torch_device)
model.eval()
inputs_dict = self.get_dummy_inputs()
first = model(**inputs_dict, return_dict=False)[0]
second = model(**inputs_dict, return_dict=False)[0]
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
first_flat = first.flatten()
second_flat = second.flatten()
@@ -398,9 +395,8 @@ class ModelTesterMixin:
model.to(torch_device)
model.eval()
inputs_dict = self.get_dummy_inputs()
outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)
outputs_dict = model(**self.get_dummy_inputs())
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
@@ -527,10 +523,8 @@ class ModelTesterMixin:
new_model = new_model.to(torch_device)
torch.manual_seed(0)
# Re-create inputs only if they contain a generator (which needs to be reset)
if "generator" in inputs_dict:
inputs_dict = self.get_dummy_inputs()
new_output = new_model(**inputs_dict, return_dict=False)[0]
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
@@ -569,10 +563,8 @@ class ModelTesterMixin:
new_model = new_model.to(torch_device)
torch.manual_seed(0)
# Re-create inputs only if they contain a generator (which needs to be reset)
if "generator" in inputs_dict:
inputs_dict = self.get_dummy_inputs()
new_output = new_model(**inputs_dict, return_dict=False)[0]
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"
@@ -622,10 +614,8 @@ class ModelTesterMixin:
model_parallel = model_parallel.to(torch_device)
torch.manual_seed(0)
# Re-create inputs only if they contain a generator (which needs to be reset)
if "generator" in inputs_dict:
inputs_dict = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict, return_dict=False)[0]
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
assert_tensors_close(
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"

View File

@@ -92,6 +92,9 @@ class TorchCompileTesterMixin:
model.eval()
model.compile_repeated_blocks(fullgraph=True)
if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=recompile_limit),

View File

@@ -15,7 +15,6 @@
import gc
import json
import logging
import os
import re
@@ -24,12 +23,10 @@ import safetensors.torch
import torch
import torch.nn as nn
from diffusers.utils import logging as diffusers_logging
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import check_if_dicts_are_equal
from ...testing_utils import (
CaptureLogger,
assert_tensors_close,
backend_empty_cache,
is_lora,
@@ -480,7 +477,10 @@ class LoraHotSwappingForModelTesterMixin:
with pytest.raises(RuntimeError, match=msg):
model.enable_lora_hotswap(target_rank=32)
def test_enable_lora_hotswap_called_after_adapter_added_warning(self):
def test_enable_lora_hotswap_called_after_adapter_added_warning(self, caplog):
# ensure that enable_lora_hotswap is called before loading the first adapter
import logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
@@ -488,26 +488,21 @@ class LoraHotSwappingForModelTesterMixin:
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
logger = diffusers_logging.get_logger("diffusers.loaders.peft")
logger.setLevel(logging.WARNING)
with CaptureLogger(logger) as cap_logger:
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in record.message for record in caplog.records)
assert msg in str(cap_logger.out), f"Expected warning not found. Captured: {cap_logger.out}"
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
# check possibility to ignore the error/warning
import logging
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
logger = diffusers_logging.get_logger("diffusers.loaders.peft")
logger.setLevel(logging.WARNING)
with CaptureLogger(logger) as cap_logger:
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert cap_logger.out == "", f"Expected no warnings but found: {cap_logger.out}"
assert len(caplog.records) == 0
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
@@ -520,6 +515,9 @@ class LoraHotSwappingForModelTesterMixin:
model.enable_lora_hotswap(target_rank=32, check_compiled="wrong-argument")
def test_hotswap_second_adapter_targets_more_layers_raises(self, tmp_path, caplog):
# check the error and log
import logging
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
target_modules0 = ["to_q"]
target_modules1 = ["to_q", "to_k"]

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
import torch
@@ -24,39 +26,64 @@ from ...testing_utils import (
slow,
torch_device,
)
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
_LAYERWISE_CASTING_XFAIL_REASON = (
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
)
class UNet1DTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet1DModel testing (standard variant)."""
class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet1DModel
main_input_name = "sample"
@property
def model_class(self):
return UNet1DModel
def dummy_input(self):
    """Build the keyword inputs (`sample`, `timestep`) used for a forward pass of the model under test."""
    # (batch_size, num_features, seq_len)
    sample_shape = (4, 14, 16)
    sample = floats_tensor(sample_shape).to(torch_device)
    # One timestep value (10) per batch entry.
    timestep = torch.tensor([10] * sample_shape[0]).to(torch_device)
    return {"sample": sample, "timestep": timestep}
@property
def input_shape(self):
return (4, 14, 16)
@property
def output_shape(self):
return (14, 16)
return (4, 14, 16)
@property
def main_input_name(self):
return "sample"
@unittest.skip("Test not supported.")
def test_ema_training(self):
pass
def get_init_dict(self):
return {
@unittest.skip("Test not supported.")
def test_training(self):
pass
@unittest.skip("Test not supported.")
def test_layerwise_casting_training(self):
pass
def test_determinism(self):
super().test_determinism()
def test_outputs_equivalence(self):
super().test_outputs_equivalence()
def test_from_save_pretrained(self):
super().test_from_save_pretrained()
def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant()
def test_model_from_pretrained(self):
super().test_model_from_pretrained()
def test_output(self):
super().test_output()
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (8, 8, 16, 16),
"in_channels": 14,
"out_channels": 14,
@@ -70,40 +97,18 @@ class UNet1DTesterConfig(BaseModelTesterConfig):
"up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
"act_fn": "swish",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
    """Build the keyword inputs (`sample`, `timestep`) used for a forward pass of the model under test."""
    # (batch_size, num_features, seq_len)
    sample_shape = (4, 14, 16)
    sample = floats_tensor(sample_shape).to(torch_device)
    # One timestep value (10) per batch entry.
    timestep = torch.tensor([10] * sample_shape[0]).to(torch_device)
    return {"sample": sample, "timestep": timestep}
class TestUNet1D(UNet1DTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Not implemented yet for this UNet")
def test_forward_with_norm_groups(self):
pass
class TestUNet1DMemory(UNet1DTesterConfig, MemoryTesterMixin):
@pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
def test_layerwise_casting_memory(self):
super().test_layerwise_casting_memory()
class TestUNet1DHubLoading(UNet1DTesterConfig):
def test_from_pretrained_hub(self):
model, loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
)
assert model is not None
assert len(loading_info["missing_keys"]) == 0
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
image = model(**self.get_dummy_inputs())
image = model(**self.dummy_input)
assert image is not None, "Make sure output is not None"
@@ -126,7 +131,12 @@ class TestUNet1DHubLoading(UNet1DTesterConfig):
# fmt: off
expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
# fmt: on
assert torch.allclose(output_slice, expected_output_slice, rtol=1e-3)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet
pass
@slow
def test_unet_1d_maestro(self):
@@ -147,29 +157,98 @@ class TestUNet1DHubLoading(UNet1DTesterConfig):
assert (output_sum - 224.0896).abs() < 0.5
assert (output_max - 0.0607).abs() < 4e-4
@pytest.mark.xfail(
reason=(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
),
)
def test_layerwise_casting_inference(self):
super().test_layerwise_casting_inference()
# =============================================================================
# UNet1D RL (Value Function) Model Tests
# =============================================================================
@pytest.mark.xfail(
reason=(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
),
)
def test_layerwise_casting_memory(self):
pass
class UNet1DRLTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet1DModel testing (RL value function variant)."""
class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet1DModel
main_input_name = "sample"
@property
def model_class(self):
return UNet1DModel
def dummy_input(self):
    """Build the keyword inputs (`sample`, `timestep`) used for a forward pass of the value-function model."""
    # (batch_size, num_features, seq_len)
    sample_shape = (4, 14, 16)
    sample = floats_tensor(sample_shape).to(torch_device)
    # One timestep value (10) per batch entry.
    timestep = torch.tensor([10] * sample_shape[0]).to(torch_device)
    return {"sample": sample, "timestep": timestep}
@property
def input_shape(self):
return (4, 14, 16)
@property
def output_shape(self):
return (1,)
return (4, 14, 1)
@property
def main_input_name(self):
return "sample"
def test_determinism(self):
super().test_determinism()
def get_init_dict(self):
return {
def test_outputs_equivalence(self):
super().test_outputs_equivalence()
def test_from_save_pretrained(self):
super().test_from_save_pretrained()
def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant()
def test_model_from_pretrained(self):
super().test_model_from_pretrained()
def test_output(self):
    """Run one forward pass and check that the output has shape `(batch, 1)`."""
    # UNetRL is a value function, so its output shape differs from the input shape.
    config, inputs = self.prepare_init_args_and_inputs_for_common()
    net = self.model_class(**config)
    net.to(torch_device)
    net.eval()

    with torch.no_grad():
        result = net(**inputs)

    # Dict-like model outputs expose the prediction as `.sample`.
    if isinstance(result, dict):
        result = result.sample

    self.assertIsNotNone(result)
    batch_size = inputs["sample"].shape[0]
    self.assertEqual(result.shape, torch.Size((batch_size, 1)), "Input and output shapes do not match")
@unittest.skip("Test not supported.")
def test_ema_training(self):
pass
@unittest.skip("Test not supported.")
def test_training(self):
pass
@unittest.skip("Test not supported.")
def test_layerwise_casting_training(self):
pass
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 14,
"out_channels": 14,
"down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],
@@ -185,54 +264,18 @@ class UNet1DRLTesterConfig(BaseModelTesterConfig):
"time_embedding_type": "positional",
"act_fn": "mish",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
    """Build the keyword inputs (`sample`, `timestep`) used for a forward pass of the value-function model."""
    # (batch_size, num_features, seq_len)
    sample_shape = (4, 14, 16)
    sample = floats_tensor(sample_shape).to(torch_device)
    # One timestep value (10) per batch entry.
    timestep = torch.tensor([10] * sample_shape[0]).to(torch_device)
    return {"sample": sample, "timestep": timestep}
class TestUNet1DRL(UNet1DRLTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Not implemented yet for this UNet")
def test_forward_with_norm_groups(self):
pass
@torch.no_grad()
def test_output(self):
    """Run one forward pass and check that the first returned element has shape `(batch, 1)`."""
    # UNetRL is a value function, so its output shape is (batch, 1) rather than the input shape.
    net = self.model_class(**self.get_init_dict())
    net.to(torch_device)
    net.eval()

    dummy = self.get_dummy_inputs()
    # With return_dict=False the model returns a tuple; the prediction is its first element.
    value = net(**dummy, return_dict=False)[0]

    assert value is not None
    batch_size = dummy["sample"].shape[0]
    assert value.shape == torch.Size((batch_size, 1)), "Input and output shapes do not match"
class TestUNet1DRLMemory(UNet1DRLTesterConfig, MemoryTesterMixin):
@pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
def test_layerwise_casting_memory(self):
super().test_layerwise_casting_memory()
class TestUNet1DRLHubLoading(UNet1DRLTesterConfig):
def test_from_pretrained_hub(self):
value_function, vf_loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
)
assert value_function is not None
assert len(vf_loading_info["missing_keys"]) == 0
self.assertIsNotNone(value_function)
self.assertEqual(len(vf_loading_info["missing_keys"]), 0)
value_function.to(torch_device)
image = value_function(**self.get_dummy_inputs())
image = value_function(**self.dummy_input)
assert image is not None, "Make sure output is not None"
@@ -256,4 +299,31 @@ class TestUNet1DRLHubLoading(UNet1DRLTesterConfig):
# fmt: off
expected_output_slice = torch.tensor([165.25] * seq_len)
# fmt: on
assert torch.allclose(output, expected_output_slice, rtol=1e-3)
self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet
pass
@pytest.mark.xfail(
reason=(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
),
)
def test_layerwise_casting_inference(self):
pass
@pytest.mark.xfail(
reason=(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
),
)
def test_layerwise_casting_memory(self):
pass

View File

@@ -15,11 +15,12 @@
import gc
import math
import unittest
import pytest
import torch
from diffusers import UNet2DModel
from diffusers.utils import logging
from ...testing_utils import (
backend_empty_cache,
@@ -30,40 +31,39 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
logger = logging.get_logger(__name__)
enable_full_determinism()
# =============================================================================
# Standard UNet2D Model Tests
# =============================================================================
class UNet2DTesterConfig(BaseModelTesterConfig):
"""Base configuration for standard UNet2DModel testing."""
class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
main_input_name = "sample"
@property
def model_class(self):
return UNet2DModel
def dummy_input(self):
    """Build the keyword inputs (`sample`, `timestep`) used for a forward pass of the 2D UNet under test."""
    spatial = (32, 32)
    # (batch_size, num_channels) followed by the spatial sizes.
    sample_shape = (4, 3) + spatial
    sample = floats_tensor(sample_shape).to(torch_device)
    # A single timestep value (10), not repeated per batch entry.
    timestep = torch.tensor([10]).to(torch_device)
    return {"sample": sample, "timestep": timestep}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (4, 8),
"norm_num_groups": 2,
"down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
@@ -74,22 +74,11 @@ class UNet2DTesterConfig(BaseModelTesterConfig):
"layers_per_block": 2,
"sample_size": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
    """Build the keyword inputs (`sample`, `timestep`) used for a forward pass of the 2D UNet under test."""
    spatial = (32, 32)
    # (batch_size, num_channels) followed by the spatial sizes.
    sample_shape = (4, 3) + spatial
    sample = floats_tensor(sample_shape).to(torch_device)
    # A single timestep value (10), not repeated per batch entry.
    timestep = torch.tensor([10]).to(torch_device)
    return {"sample": sample, "timestep": timestep}
class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
def test_mid_block_attn_groups(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["add_attention"] = True
init_dict["attn_norm_num_groups"] = 4
@@ -98,11 +87,13 @@ class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
model.to(torch_device)
model.eval()
assert model.mid_block.attentions[0].group_norm is not None, (
"Mid block Attention group norm should exist but does not."
self.assertIsNotNone(
model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not."
)
assert model.mid_block.attentions[0].group_norm.num_groups == init_dict["attn_norm_num_groups"], (
"Mid block Attention group norm does not have the expected number of groups."
self.assertEqual(
model.mid_block.attentions[0].group_norm.num_groups,
init_dict["attn_norm_num_groups"],
"Mid block Attention group norm does not have the expected number of groups.",
)
with torch.no_grad():
@@ -111,15 +102,13 @@ class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
if isinstance(output, dict):
output = output.to_tuple()[0]
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_mid_block_none(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
mid_none_init_dict = self.get_init_dict()
mid_none_inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
mid_none_init_dict["mid_block_type"] = None
model = self.model_class(**init_dict)
@@ -130,7 +119,7 @@ class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
mid_none_model.to(torch_device)
mid_none_model.eval()
assert mid_none_model.mid_block is None, "Mid block should not exist."
self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")
with torch.no_grad():
output = model(**inputs_dict)
@@ -144,10 +133,8 @@ class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
if isinstance(mid_none_output, dict):
mid_none_output = mid_none_output.to_tuple()[0]
assert not torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different."
self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")
class TestUNet2DTraining(UNet2DTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"AttnUpBlock2D",
@@ -156,32 +143,41 @@ class TestUNet2DTraining(UNet2DTesterConfig, TrainingTesterMixin):
"UpBlock2D",
"DownBlock2D",
}
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
attention_head_dim = 8
block_out_channels = (16, 32)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
)
# =============================================================================
# UNet2D LDM Model Tests
# =============================================================================
class UNet2DLDMTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet2DModel LDM variant testing."""
class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
main_input_name = "sample"
@property
def model_class(self):
return UNet2DModel
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"sample_size": 32,
"in_channels": 4,
"out_channels": 4,
@@ -191,34 +187,17 @@ class UNet2DLDMTesterConfig(BaseModelTesterConfig):
"down_block_types": ("DownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "UpBlock2D"),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
    """Return the dummy forward-pass inputs: a random 4-channel `sample` tensor and a single `timestep`."""
    # (batch_size, num_channels) + sizes
    sample_shape = (4, 4) + (32, 32)
    sample = floats_tensor(sample_shape).to(torch_device)
    timestep = torch.tensor([10]).to(torch_device)
    return {"sample": sample, "timestep": timestep}
class TestUNet2DLDMTraining(UNet2DLDMTesterConfig, TrainingTesterMixin):
    """Training tests for the LDM variant of UNet2DModel."""

    def test_gradient_checkpointing_is_applied(self):
        """Delegate to the inherited check, passing the block class names as `expected_set`."""
        # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
        checkpointed_blocks = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
        super().test_gradient_checkpointing_is_applied(expected_set=checkpointed_blocks)
class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig):
def test_from_pretrained_hub(self):
model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
assert model is not None
assert len(loading_info["missing_keys"]) == 0
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
image = model(**self.get_dummy_inputs()).sample
image = model(**self.dummy_input).sample
assert image is not None, "Make sure output is not None"
@@ -226,7 +205,7 @@ class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig):
def test_from_pretrained_accelerate(self):
model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model.to(torch_device)
image = model(**self.get_dummy_inputs()).sample
image = model(**self.dummy_input).sample
assert image is not None, "Make sure output is not None"
@@ -286,31 +265,44 @@ class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig):
expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
# fmt: on
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-3)
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))
def test_gradient_checkpointing_is_applied(self):
    # Delegates to the inherited check, overriding `attention_head_dim` and
    # `block_out_channels` alongside the expected block class names.
    # NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
    super().test_gradient_checkpointing_is_applied(
        expected_set={"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"},
        attention_head_dim=32,
        block_out_channels=(32, 64),
    )
# =============================================================================
# NCSN++ Model Tests
# =============================================================================
class NCSNppTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet2DModel NCSN++ variant testing."""
class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
main_input_name = "sample"
@property
def model_class(self):
return UNet2DModel
def dummy_input(self, sizes=(32, 32)):
batch_size = 4
num_channels = 3
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": [32, 64, 64, 64],
"in_channels": 3,
"layers_per_block": 1,
@@ -332,71 +324,17 @@ class NCSNppTesterConfig(BaseModelTesterConfig):
"SkipUpBlock2D",
],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
    """Return the dummy forward-pass inputs: a random `sample` tensor and one int32 `timestep` per batch item."""
    batch = 4
    # (batch_size, num_channels) + sizes
    sample = floats_tensor((batch, 3) + (32, 32)).to(torch_device)
    # One timestep value (10) per batch element, cast to int32 on the test device.
    timestep = torch.tensor(batch * [10]).to(dtype=torch.int32, device=torch_device)
    return {"sample": sample, "timestep": timestep}
class TestNCSNpp(NCSNppTesterConfig, ModelTesterMixin, UNetTesterMixin):
    """Model tests for the NCSN++ variant; the overrides below only skip inherited tests."""

    @pytest.mark.skip(reason="Test not supported.")
    def test_forward_with_norm_groups(self):
        pass

    @pytest.mark.skip(
        reason="To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
    )
    def test_keep_in_fp32_modules(self):
        pass

    @pytest.mark.skip(
        reason="To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
    )
    def test_from_save_pretrained_dtype_inference(self):
        pass
class TestNCSNppMemory(NCSNppTesterConfig, MemoryTesterMixin):
    """Memory tests for the NCSN++ variant; the layerwise-casting tests are skipped."""

    @pytest.mark.skip(
        reason="To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
    )
    def test_layerwise_casting_memory(self):
        pass

    @pytest.mark.skip(
        reason="To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
    )
    def test_layerwise_casting_training(self):
        pass
class TestNCSNppTraining(NCSNppTesterConfig, TrainingTesterMixin):
    """Training tests for the NCSN++ variant of UNet2DModel."""

    def test_gradient_checkpointing_is_applied(self):
        """Delegate to the inherited check with `UNetMidBlock2D` as the only expected class name."""
        checkpointed_blocks = {"UNetMidBlock2D"}
        super().test_gradient_checkpointing_is_applied(expected_set=checkpointed_blocks)
class TestNCSNppHubLoading(NCSNppTesterConfig):
@slow
def test_from_pretrained_hub(self):
model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
assert model is not None
assert len(loading_info["missing_keys"]) == 0
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
inputs = self.get_dummy_inputs()
inputs = self.dummy_input
noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
inputs["sample"] = noise
image = model(**inputs)
@@ -423,7 +361,7 @@ class TestNCSNppHubLoading(NCSNppTesterConfig):
expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056])
# fmt: on
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
def test_output_pretrained_ve_large(self):
model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
@@ -444,4 +382,35 @@ class TestNCSNppHubLoading(NCSNppTesterConfig):
expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
# fmt: on
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
@unittest.skip(reason="Test not supported.")
def test_forward_with_norm_groups(self):
    # Overridden only to skip the inherited test: not required for this model.
    pass
def test_gradient_checkpointing_is_applied(self):
    # Delegates to the inherited check with `UNetMidBlock2D` as the only expected
    # class name and an explicit `block_out_channels` override.
    super().test_gradient_checkpointing_is_applied(
        expected_set={"UNetMidBlock2D"},
        block_out_channels=(32, 64, 64, 64),
    )
def test_effective_gradient_checkpointing(self):
    # Run the inherited check, passing `time_proj.weight` in the `skip` set.
    skipped_names = {"time_proj.weight"}
    super().test_effective_gradient_checkpointing(skip=skipped_names)
@unittest.skip(
    reason="To make layerwise casting work with this model, we will have to update the implementation. "
    "Due to potentially low usage, we don't support it here."
)
def test_layerwise_casting_inference(self):
    # Overridden only so the inherited test is skipped for this model.
    pass
@unittest.skip(
    reason="To make layerwise casting work with this model, we will have to update the implementation. "
    "Due to potentially low usage, we don't support it here."
)
def test_layerwise_casting_memory(self):
    # Overridden only so the inherited test is skipped for this model.
    pass

View File

@@ -20,7 +20,6 @@ import tempfile
import unittest
from collections import OrderedDict
import pytest
import torch
from huggingface_hub import snapshot_download
from parameterized import parameterized
@@ -53,24 +52,17 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
IPAdapterTesterMixin,
from ..test_modeling_common import (
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
UNetTesterMixin,
)
if is_peft_available():
from peft import LoraConfig
from ..testing_utils.lora import check_if_lora_correctly_set
from peft.tuners.tuners_utils import BaseTunerLayer
logger = logging.get_logger(__name__)
@@ -90,6 +82,16 @@ def get_unet_lora_config():
return unet_lora_config
def check_if_lora_correctly_set(model) -> bool:
    """
    Report whether `model` contains at least one peft `BaseTunerLayer` submodule.

    Args:
        model: Object exposing `modules()`, which is iterated in search of tuner layers.

    Returns:
        `True` as soon as a `BaseTunerLayer` instance is found among `model.modules()`, `False` otherwise.
    """
    return any(isinstance(submodule, BaseTunerLayer) for submodule in model.modules())
def create_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
ip_cross_attn_state_dict = {}
@@ -352,28 +354,34 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
return custom_diffusion_attn_procs
class UNet2DConditionTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet2DConditionModel testing."""
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
main_input_name = "sample"
# We override the items here because the unet under consideration is small.
model_split_percents = [0.5, 0.34, 0.4]
@property
def model_class(self):
return UNet2DConditionModel
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (16, 16)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@property
def output_shape(self) -> tuple[int, int, int]:
def input_shape(self):
return (4, 16, 16)
@property
def model_split_percents(self) -> list[float]:
return [0.5, 0.34, 0.4]
def output_shape(self):
return (4, 16, 16)
@property
def main_input_name(self) -> str:
return "sample"
def get_init_dict(self) -> dict:
"""Return UNet2D model initialization arguments."""
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (4, 8),
"norm_num_groups": 4,
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
@@ -385,24 +393,26 @@ class UNet2DConditionTesterConfig(BaseModelTesterConfig):
"layers_per_block": 1,
"sample_size": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
"""Return dummy inputs for UNet2D model."""
batch_size = 4
num_channels = 4
sizes = (16, 16)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
}
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin):
def test_model_with_attention_head_dim_tuple(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -417,13 +427,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
if isinstance(output, dict):
output = output.sample
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_use_linear_projection(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["use_linear_projection"] = True
@@ -437,13 +446,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
if isinstance(output, dict):
output = output.sample
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_cross_attention_dim_tuple(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["cross_attention_dim"] = (8, 8)
@@ -457,13 +465,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
if isinstance(output, dict):
output = output.sample
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_simple_projection(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
batch_size, _, _, sample_size = inputs_dict["sample"].shape
@@ -482,13 +489,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
if isinstance(output, dict):
output = output.sample
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_class_embeddings_concat(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
batch_size, _, _, sample_size = inputs_dict["sample"].shape
@@ -508,287 +514,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
if isinstance(output, dict):
output = output.sample
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
# since the use-case (somebody passes in a too-short cross-attn mask) is pretty small,
# maybe it's fine that this only works for the unclip use-case.
@mark.skip(
reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
)
def test_model_xattn_padding(self):
# Currently skipped (see the decorator reason). Compares three forward passes: no mask, a mask
# that is True only at the last token, and an all-False mask that is one token shorter than the condition.
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
# Build the model from the shared config with `attention_head_dim` overridden to a tuple.
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
model.to(torch_device)
model.eval()
cond = inputs_dict["encoder_hidden_states"]
with torch.no_grad():
# Baseline: forward pass without any `encoder_attention_mask`.
full_cond_out = model(**inputs_dict).sample
assert full_cond_out is not None
batch, tokens, _ = cond.shape
# Boolean mask of shape (batch, tokens) that is True only at the last token index.
keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"
# All-False mask with `tokens - 1` columns, i.e. one entry short of the condition length.
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
assert trunc_mask_out.allclose(keeplast_out), (
"a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
)
def test_pickle(self):
# Runs one forward pass and checks that a `copy.copy` of the output matches the original
# to within 1e-4. NOTE(review): despite the test name, no `pickle` call is made here.
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
# Use a wider config with a per-block tuple for `attention_head_dim`.
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
sample = model(**inputs_dict).sample
sample_copy = copy.copy(sample)
# Largest absolute element-wise difference between the output and its copy must stay below 1e-4.
assert (sample - sample_copy).abs().max() < 1e-4
def test_asymmetrical_unet(self):
    """Check that differing down/up transformer layer counts still yield an output shaped like the input."""
    config = self.get_init_dict()
    dummy_inputs = self.get_dummy_inputs()
    # Add asymmetry to configs
    config.update(
        transformer_layers_per_block=[[3, 2], 1],
        reverse_transformer_layers_per_block=[[3, 4], 1],
    )

    torch.manual_seed(0)
    unet = self.model_class(**config)
    unet.to(torch_device)

    prediction = unet(**dummy_inputs).sample
    input_shape = dummy_inputs["sample"].shape

    # Check if input and output shapes are the same
    assert prediction.shape == input_shape, "Input and output shapes do not match"
class TestUNet2DConditionHubLoading(UNet2DConditionTesterConfig):
    """Hub checkpoint loading tests for UNet2DConditionModel.

    Each test loads a sharded checkpoint (by repo id or from a local snapshot, with or without a
    `unet` subfolder, with or without `device_map="auto"`), runs one forward pass on the dummy
    inputs and asserts the output `sample` has shape (4, 4, 16, 16).
    """

    @parameterized.expand(
        [
            ("hf-internal-testing/unet2d-sharded-dummy", None),
            ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
        ]
    )
    @require_torch_accelerator
    def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
        dummy_inputs = self.get_dummy_inputs()
        unet = self.model_class.from_pretrained(repo_id, variant=variant).to(torch_device)
        prediction = unet(**dummy_inputs)

        assert unet
        assert prediction.sample.shape == (4, 4, 16, 16)

    @parameterized.expand(
        [
            ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
            ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
        ]
    )
    @require_torch_accelerator
    def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
        dummy_inputs = self.get_dummy_inputs()
        unet = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant).to(torch_device)
        prediction = unet(**dummy_inputs)

        assert unet
        assert prediction.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    def test_load_sharded_checkpoint_from_hub_local(self):
        dummy_inputs = self.get_dummy_inputs()
        # Download the snapshot first, then load strictly from the local copy.
        local_dir = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
        unet = self.model_class.from_pretrained(local_dir, local_files_only=True).to(torch_device)
        prediction = unet(**dummy_inputs)

        assert unet
        assert prediction.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
        dummy_inputs = self.get_dummy_inputs()
        local_dir = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
        unet = self.model_class.from_pretrained(local_dir, subfolder="unet", local_files_only=True).to(
            torch_device
        )
        prediction = unet(**dummy_inputs)

        assert unet
        assert prediction.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    @parameterized.expand(
        [
            ("hf-internal-testing/unet2d-sharded-dummy", None),
            ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
        ]
    )
    def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
        dummy_inputs = self.get_dummy_inputs()
        # With `device_map="auto"` the model is not moved explicitly afterwards.
        unet = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
        prediction = unet(**dummy_inputs)

        assert unet
        assert prediction.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    @parameterized.expand(
        [
            ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
            ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
        ]
    )
    def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
        dummy_inputs = self.get_dummy_inputs()
        unet = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
        prediction = unet(**dummy_inputs)

        assert unet
        assert prediction.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    def test_load_sharded_checkpoint_device_map_from_hub_local(self):
        dummy_inputs = self.get_dummy_inputs()
        local_dir = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
        unet = self.model_class.from_pretrained(local_dir, local_files_only=True, device_map="auto")
        prediction = unet(**dummy_inputs)

        assert unet
        assert prediction.sample.shape == (4, 4, 16, 16)

    @require_torch_accelerator
    def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
        dummy_inputs = self.get_dummy_inputs()
        local_dir = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
        unet = self.model_class.from_pretrained(
            local_dir, local_files_only=True, subfolder="unet", device_map="auto"
        )
        prediction = unet(**dummy_inputs)

        assert unet
        assert prediction.sample.shape == (4, 4, 16, 16)
class TestUNet2DConditionLoRA(UNet2DConditionTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for UNet2DConditionModel, covering the deprecated `load_attn_procs`/`save_attn_procs` paths."""
@require_peft_backend
def test_load_attn_procs_raise_warning(self):
"""Test that deprecated load_attn_procs method raises FutureWarning.

Also checks that the LoRA output differs from the plain output and that reloading the saved
weights reproduces the LoRA output.
"""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
# forward pass without LoRA
with torch.no_grad():
non_lora_sample = model(**inputs_dict).sample
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
# forward pass with LoRA
with torch.no_grad():
lora_sample_1 = model(**inputs_dict).sample
# Round-trip: save the adapter, unload it, then reload it through the deprecated API.
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
model.unload_lora()
with pytest.warns(FutureWarning, match="Using the `load_attn_procs\\(\\)` method has been deprecated"):
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
# important to still check for the rest of the stuff.
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with torch.no_grad():
lora_sample_2 = model(**inputs_dict).sample
assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
"LoRA injected UNet should produce different results."
)
assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
"Loading from a saved checkpoint should produce identical results."
)
@require_peft_backend
def test_save_attn_procs_raise_warning(self):
"""Test that deprecated save_attn_procs method raises FutureWarning."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
# Saving into a temporary directory must emit the deprecation warning.
with tempfile.TemporaryDirectory() as tmpdirname:
with pytest.warns(FutureWarning, match="Using the `save_attn_procs\\(\\)` method has been deprecated"):
model.save_attn_procs(os.path.join(tmpdirname))
class TestUNet2DConditionMemory(UNet2DConditionTesterConfig, MemoryTesterMixin):
    """Memory optimization tests for UNet2DConditionModel.

    Defines no tests of its own; everything comes from its base classes.
    """
class TestUNet2DConditionTraining(UNet2DConditionTesterConfig, TrainingTesterMixin):
    """Training tests for UNet2DConditionModel."""

    def test_gradient_checkpointing_is_applied(self):
        """Delegate to the inherited check, passing the block class names as `expected_set`."""
        checkpointed_blocks = {
            "CrossAttnDownBlock2D",
            "DownBlock2D",
            "UNetMidBlock2DCrossAttn",
            "Transformer2DModel",
            "CrossAttnUpBlock2D",
            "UpBlock2D",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=checkpointed_blocks)
class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterMixin):
"""Attention processor tests for UNet2DConditionModel."""
@unittest.skipIf(
    torch_device != "cuda" or not is_xformers_available(),
    reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
    """Enable xformers attention and check the first mid-block self-attention processor's class name."""
    unet = self.model_class(**self.get_init_dict())
    unet.enable_xformers_memory_efficient_attention()

    first_mid_attn = unet.mid_block.attentions[0].transformer_blocks[0].attn1
    processor_name = first_mid_attn.processor.__class__.__name__
    assert processor_name == "XFormersAttnProcessor", "xformers is not enabled"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_attention_slicing(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -813,7 +544,7 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
assert output is not None
def test_model_sliceable_head_dim(self):
init_dict = self.get_init_dict()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -831,6 +562,21 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
for module in model.children():
check_sliceable_dim_attr(module)
def test_gradient_checkpointing_is_applied(self):
    # Delegates to the inherited check with tuple-valued `attention_head_dim` and
    # wider `block_out_channels`, alongside the expected block class names.
    checkpointed_blocks = {
        "CrossAttnDownBlock2D",
        "DownBlock2D",
        "UNetMidBlock2DCrossAttn",
        "Transformer2DModel",
        "CrossAttnUpBlock2D",
        "UpBlock2D",
    }
    super().test_gradient_checkpointing_is_applied(
        expected_set=checkpointed_blocks,
        attention_head_dim=(8, 16),
        block_out_channels=(16, 32),
    )
def test_special_attn_proc(self):
class AttnEasyProc(torch.nn.Module):
def __init__(self, num):
@@ -872,8 +618,7 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
return hidden_states
# enable deterministic behavior for gradient checkpointing
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -900,8 +645,7 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
]
)
def test_model_xattn_mask(self, mask_dtype):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)})
model.to(torch_device)
@@ -931,13 +675,39 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
"masking the last token from our cond should be equivalent to truncating that token out of the condition"
)
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
# since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric.
# maybe it's fine that this only works for the unclip use-case.
@mark.skip(
reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
)
def test_model_xattn_padding(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
"""Custom Diffusion processor tests for UNet2DConditionModel."""
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
model.to(torch_device)
model.eval()
cond = inputs_dict["encoder_hidden_states"]
with torch.no_grad():
full_cond_out = model(**inputs_dict).sample
assert full_cond_out is not None
batch, tokens, _ = cond.shape
keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
assert trunc_mask_out.allclose(keeplast_out), (
"a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
)
def test_custom_diffusion_processors(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -963,8 +733,8 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
assert (sample1 - sample2).abs().max() < 3e-3
def test_custom_diffusion_save_load(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -984,7 +754,7 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname, safe_serialization=False)
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin")
@@ -1003,8 +773,8 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_custom_diffusion_xformers_on_off(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -1028,28 +798,41 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
assert (sample - on_sample).abs().max() < 1e-4
assert (sample - off_sample).abs().max() < 1e-4
def test_pickle(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterMixin):
"""IP Adapter tests for UNet2DConditionModel."""
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@property
def ip_adapter_processor_cls(self):
return (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)
model = self.model_class(**init_dict)
model.to(torch_device)
def create_ip_adapter_state_dict(self, model):
return create_ip_adapter_state_dict(model)
with torch.no_grad():
sample = model(**inputs_dict).sample
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
batch_size = inputs_dict["encoder_hidden_states"].shape[0]
# for ip-adapter image_embeds has shape [batch_size, num_image, embed_dim]
cross_attention_dim = getattr(model.config, "cross_attention_dim", 8)
image_embeds = floats_tensor((batch_size, 1, cross_attention_dim)).to(torch_device)
inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]}
return inputs_dict
sample_copy = copy.copy(sample)
assert (sample - sample_copy).abs().max() < 1e-4
def test_asymmetrical_unet(self):
    """Build the model with two different per-block transformer layer settings and check the output shape."""
    config, batch = self.prepare_init_args_and_inputs_for_common()
    # Make the regular and the "reverse" transformer layer counts differ from each other.
    config["transformer_layers_per_block"] = [[3, 2], 1]
    config["reverse_transformer_layers_per_block"] = [[3, 4], 1]

    torch.manual_seed(0)
    unet = self.model_class(**config)
    unet.to(torch_device)

    prediction = unet(**batch).sample
    input_shape = batch["sample"].shape

    # The prediction must have exactly the shape of the input sample.
    self.assertEqual(prediction.shape, input_shape, "Input and output shapes do not match")
def test_ip_adapter(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -1122,8 +905,7 @@ class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterM
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
def test_ip_adapter_plus(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -1195,16 +977,185 @@ class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterM
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
class TestUNet2DConditionModelCompile(UNet2DConditionTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for UNet2DConditionModel."""
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
def test_torch_compile_repeated_blocks(self):
return super().test_torch_compile_repeated_blocks(recompile_limit=2)
@parameterized.expand(
    [
        ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
        ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
    ]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
    """Load a sharded checkpoint from the ``unet`` subfolder of ``repo_id`` and run one forward pass."""
    _init_args, batch = self.prepare_init_args_and_inputs_for_common()

    unet = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
    unet = unet.to(torch_device)
    result = unet(**batch)

    expected_shape = (4, 4, 16, 16)
    assert unet
    assert result.sample.shape == expected_shape
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_local(self):
    """Download a sharded checkpoint snapshot, load it with ``local_files_only=True`` and run a forward pass."""
    _init_args, batch = self.prepare_init_args_and_inputs_for_common()
    snapshot_dir = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")

    unet = self.model_class.from_pretrained(snapshot_dir, local_files_only=True)
    unet = unet.to(torch_device)
    result = unet(**batch)

    expected_shape = (4, 4, 16, 16)
    assert unet
    assert result.sample.shape == expected_shape
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
    """Load a downloaded sharded snapshot from its ``unet`` subfolder with ``local_files_only=True``."""
    _init_args, batch = self.prepare_init_args_and_inputs_for_common()
    snapshot_dir = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")

    unet = self.model_class.from_pretrained(snapshot_dir, subfolder="unet", local_files_only=True)
    unet = unet.to(torch_device)
    result = unet(**batch)

    expected_shape = (4, 4, 16, 16)
    assert unet
    assert result.sample.shape == expected_shape
@require_torch_accelerator
@parameterized.expand(
    [
        ("hf-internal-testing/unet2d-sharded-dummy", None),
        ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
    ]
)
def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
    """Load a sharded checkpoint from ``repo_id`` with ``device_map="auto"`` and run one forward pass."""
    _init_args, batch = self.prepare_init_args_and_inputs_for_common()

    # No explicit `.to(...)` here: placement is left to `device_map="auto"`.
    unet = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
    result = unet(**batch)

    expected_shape = (4, 4, 16, 16)
    assert unet
    assert result.sample.shape == expected_shape
@require_torch_accelerator
@parameterized.expand(
    [
        ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
        ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
    ]
)
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
    """Load a sharded checkpoint from the ``unet`` subfolder of ``repo_id`` with ``device_map="auto"``."""
    _init_args, batch = self.prepare_init_args_and_inputs_for_common()

    # No explicit `.to(...)` here: placement is left to `device_map="auto"`.
    unet = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
    result = unet(**batch)

    expected_shape = (4, 4, 16, 16)
    assert unet
    assert result.sample.shape == expected_shape
@require_torch_accelerator
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
    """Load a downloaded sharded snapshot with ``local_files_only=True`` and ``device_map="auto"``."""
    _init_args, batch = self.prepare_init_args_and_inputs_for_common()
    snapshot_dir = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")

    unet = self.model_class.from_pretrained(snapshot_dir, local_files_only=True, device_map="auto")
    result = unet(**batch)

    expected_shape = (4, 4, 16, 16)
    assert unet
    assert result.sample.shape == expected_shape
@require_torch_accelerator
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
    """Load a downloaded sharded snapshot from its ``unet`` subfolder with ``device_map="auto"``."""
    _init_args, batch = self.prepare_init_args_and_inputs_for_common()
    snapshot_dir = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")

    unet = self.model_class.from_pretrained(
        snapshot_dir,
        local_files_only=True,
        subfolder="unet",
        device_map="auto",
    )
    result = unet(**batch)

    expected_shape = (4, 4, 16, 16)
    assert unet
    assert result.sample.shape == expected_shape
@require_peft_backend
def test_load_attn_procs_raise_warning(self):
    """Check that ``load_attn_procs()`` raises a deprecation ``FutureWarning`` and still restores the LoRA."""
    config, batch = self.prepare_init_args_and_inputs_for_common()
    unet = self.model_class(**config)
    unet.to(torch_device)

    # Reference output before any adapter is attached.
    with torch.no_grad():
        baseline_sample = unet(**batch).sample

    lora_config = get_unet_lora_config()
    unet.add_adapter(lora_config)
    assert check_if_lora_correctly_set(unet), "Lora not correctly set in UNet."

    # Output with the freshly added adapter.
    with torch.no_grad():
        lora_sample_before_reload = unet(**batch).sample

    with tempfile.TemporaryDirectory() as tmpdirname:
        unet.save_attn_procs(tmpdirname)
        unet.unload_lora()

        weights_file = os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")
        with self.assertWarns(FutureWarning) as caught:
            unet.load_attn_procs(weights_file)

        first_message = str(caught.warnings[0].message)
        assert "Using the `load_attn_procs()` method has been deprecated" in first_message

        # Beyond the warning, it is still important to verify that the adapter was actually re-attached.
        assert check_if_lora_correctly_set(unet), "Lora not correctly set in UNet."

        with torch.no_grad():
            lora_sample_after_reload = unet(**batch).sample

    assert not torch.allclose(baseline_sample, lora_sample_before_reload, atol=1e-4, rtol=1e-4), (
        "LoRA injected UNet should produce different results."
    )
    assert torch.allclose(lora_sample_before_reload, lora_sample_after_reload, atol=1e-4, rtol=1e-4), (
        "Loading from a saved checkpoint should produce identical results."
    )
@require_peft_backend
def test_save_attn_procs_raise_warning(self):
    """Check that ``save_attn_procs()`` raises a ``FutureWarning`` carrying the deprecation message."""
    config, _inputs = self.prepare_init_args_and_inputs_for_common()
    unet = self.model_class(**config)
    unet.to(torch_device)

    unet.add_adapter(get_unet_lora_config())
    assert check_if_lora_correctly_set(unet), "Lora not correctly set in UNet."

    # Save into a throwaway directory while capturing the warnings raised by the call.
    with tempfile.TemporaryDirectory() as tmpdirname, self.assertWarns(FutureWarning) as caught:
        unet.save_attn_procs(tmpdirname)

    first_message = str(caught.warnings[0].message)
    assert "Using the `save_attn_procs()` method has been deprecated" in first_message
class TestUNet2DConditionModelLoRAHotSwap(UNet2DConditionTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for UNet2DConditionModel."""
class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    """Runs ``TorchCompileTesterMixin`` against ``UNet2DConditionModel``."""

    model_class = UNet2DConditionModel

    def prepare_init_args_and_inputs_for_common(self):
        """Return the init args and inputs produced by a fresh ``UNet2DConditionModelTests`` instance."""
        shared_tests = UNet2DConditionModelTests()
        return shared_tests.prepare_init_args_and_inputs_for_common()
class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
    """Runs ``LoraHotSwappingForModelTesterMixin`` against ``UNet2DConditionModel``."""

    model_class = UNet2DConditionModel

    def prepare_init_args_and_inputs_for_common(self):
        """Return the init args and inputs produced by a fresh ``UNet2DConditionModelTests`` instance."""
        shared_tests = UNet2DConditionModelTests()
        return shared_tests.prepare_init_args_and_inputs_for_common()
@slow

View File

@@ -18,44 +18,47 @@ import unittest
import numpy as np
import torch
from diffusers import UNet3DConditionModel
from diffusers.models import ModelMixin, UNet3DConditionModel
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from ...testing_utils import (
enable_full_determinism,
floats_tensor,
skip_mps,
torch_device,
)
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
ModelTesterMixin,
)
from ...testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism()
logger = logging.get_logger(__name__)
@skip_mps
class UNet3DConditionTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet3DConditionModel testing."""
class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet3DConditionModel
main_input_name = "sample"
@property
def model_class(self):
return UNet3DConditionModel
def dummy_input(self):
batch_size = 4
num_channels = 4
num_frames = 4
sizes = (16, 16)
noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@property
def input_shape(self):
return (4, 4, 16, 16)
@property
def output_shape(self):
return (4, 4, 16, 16)
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (4, 8),
"norm_num_groups": 4,
"down_block_types": (
@@ -70,25 +73,27 @@ class UNet3DConditionTesterConfig(BaseModelTesterConfig):
"layers_per_block": 1,
"sample_size": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_channels = 4
num_frames = 4
sizes = (16, 16)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
return {
"sample": floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
}
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
class TestUNet3DCondition(UNet3DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin):
# Overriding to set `norm_num_groups` needs to be different for this model.
def test_forward_with_norm_groups(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (32, 64)
init_dict["norm_num_groups"] = 32
@@ -102,74 +107,39 @@ class TestUNet3DCondition(UNet3DConditionTesterConfig, ModelTesterMixin, UNetTes
if isinstance(output, dict):
output = output.sample
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
# Overriding since the UNet3D outputs a different structure.
@torch.no_grad()
def test_determinism(self):
model = self.model_class(**self.get_init_dict())
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
inputs_dict = self.get_dummy_inputs()
with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
model(**self.dummy_input)
first = model(**inputs_dict)
if isinstance(first, dict):
first = first.sample
first = model(**inputs_dict)
if isinstance(first, dict):
first = first.sample
second = model(**inputs_dict)
if isinstance(second, dict):
second = second.sample
second = model(**inputs_dict)
if isinstance(second, dict):
second = second.sample
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
assert max_diff <= 1e-5
def test_feed_forward_chunking(self):
    """Check that enabling forward chunking leaves the model output essentially unchanged.

    The model is built with ``block_out_channels=(32, 64)`` and ``norm_num_groups=32``, run once as-is and
    once after ``enable_forward_chunking()``. Both outputs must share a shape and differ by less than
    ``1e-2`` in absolute value.
    """
    init_dict = self.get_init_dict()
    inputs_dict = self.get_dummy_inputs()

    init_dict["block_out_channels"] = (32, 64)
    init_dict["norm_num_groups"] = 32

    model = self.model_class(**init_dict)
    model.to(torch_device)
    model.eval()

    with torch.no_grad():
        output = model(**inputs_dict)[0]

    model.enable_forward_chunking()
    with torch.no_grad():
        output_2 = model(**inputs_dict)[0]

    assert output.shape == output_2.shape, "Shape doesn't match"
    # Compare with torch ops: both outputs are torch tensors, so there is no need to route them through a
    # NumPy ufunc (which depends on implicit tensor <-> ndarray conversion) or to copy them to the CPU first.
    assert (output - output_2).abs().max() < 1e-2
class TestUNet3DConditionAttention(UNet3DConditionTesterConfig, AttentionTesterMixin):
@unittest.skipIf(
    not (torch_device == "cuda" and is_xformers_available()),
    reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
    """After enabling xformers attention, the first mid-block ``attn1`` must use ``XFormersAttnProcessor``."""
    unet = self.model_class(**self.get_init_dict())
    unet.enable_xformers_memory_efficient_attention()

    first_attn = unet.mid_block.attentions[0].transformer_blocks[0].attn1
    processor_name = first_attn.processor.__class__.__name__
    assert processor_name == "XFormersAttnProcessor", "xformers is not enabled"
self.assertLessEqual(max_diff, 1e-5)
def test_model_attention_slicing(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = 8
@@ -192,3 +162,22 @@ class TestUNet3DConditionAttention(UNet3DConditionTesterConfig, AttentionTesterM
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
def test_feed_forward_chunking(self):
    """Check that enabling forward chunking leaves the model output essentially unchanged.

    The model is built with ``block_out_channels=(32, 64)`` and ``norm_num_groups=32``, run once as-is and
    once after ``enable_forward_chunking()``. Both outputs must share a shape and differ by less than
    ``1e-2`` in absolute value.
    """
    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
    init_dict["block_out_channels"] = (32, 64)
    init_dict["norm_num_groups"] = 32

    model = self.model_class(**init_dict)
    model.to(torch_device)
    model.eval()

    with torch.no_grad():
        output = model(**inputs_dict)[0]

    model.enable_forward_chunking()
    with torch.no_grad():
        output_2 = model(**inputs_dict)[0]

    self.assertEqual(output.shape, output_2.shape, "Shape doesn't match")
    # Compare with torch ops: both outputs are torch tensors, so there is no need to route them through a
    # NumPy ufunc (which depends on implicit tensor <-> ndarray conversion) or to copy them to the CPU first.
    assert (output - output_2).abs().max() < 1e-2

View File

@@ -13,42 +13,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import pytest
import torch
from torch import nn
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
from diffusers.utils import logging
from ...testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
logger = logging.get_logger(__name__)
enable_full_determinism()
class UNetControlNetXSTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNetControlNetXSModel testing."""
class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNetControlNetXSModel
main_input_name = "sample"
@property
def model_class(self):
return UNetControlNetXSModel
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (16, 16)
conditioning_image_size = (3, 32, 32) # size of additional, unprocessed image for control-conditioning
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device)
conditioning_scale = 1
return {
"sample": noise,
"timestep": time_step,
"encoder_hidden_states": encoder_hidden_states,
"controlnet_cond": controlnet_cond,
"conditioning_scale": conditioning_scale,
}
@property
def input_shape(self):
return (4, 16, 16)
@property
def output_shape(self):
return (4, 16, 16)
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"sample_size": 16,
"down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
"up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
@@ -63,23 +80,11 @@ class UNetControlNetXSTesterConfig(BaseModelTesterConfig):
"ctrl_max_norm_num_groups": 2,
"ctrl_conditioning_embedding_out_channels": (2, 2),
}
def get_dummy_inputs(self):
batch_size = 4
num_channels = 4
sizes = (16, 16)
conditioning_image_size = (3, 32, 32)
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
"controlnet_cond": floats_tensor((batch_size, *conditioning_image_size)).to(torch_device),
"conditioning_scale": 1,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_unet(self):
"""Build the underlying UNet for tests that construct UNetControlNetXSModel from UNet + Adapter."""
"""For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
return UNet2DConditionModel(
block_out_channels=(4, 8),
layers_per_block=2,
@@ -94,16 +99,10 @@ class UNetControlNetXSTesterConfig(BaseModelTesterConfig):
)
def get_dummy_controlnet_from_unet(self, unet, **kwargs):
"""Build the ControlNetXS-Adapter from a UNet."""
"""For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
# size_ratio and conditioning_embedding_out_channels chosen to keep model small
return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs)
class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Test not supported.")
def test_forward_with_norm_groups(self):
    """Skipped: UNetControlNetXSModel only supports SD/SDXL with ``norm_num_groups=32``."""
def test_from_unet(self):
unet = self.get_dummy_unet()
controlnet = self.get_dummy_controlnet_from_unet(unet)
@@ -116,7 +115,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value)
# # check unet
# everything except down,mid,up blocks
# everything except down,mid,up blocks
modules_from_unet = [
"time_embedding",
"conv_in",
@@ -153,7 +152,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers")
# # check controlnet
# everything except down,mid,up blocks
# everything except down,mid,up blocks
modules_from_controlnet = {
"controlnet_cond_embedding": "controlnet_cond_embedding",
"conv_in": "ctrl_conv_in",
@@ -194,12 +193,12 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
for p in module.parameters():
assert p.requires_grad
init_dict = self.get_init_dict()
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = UNetControlNetXSModel(**init_dict)
model.freeze_unet_params()
# # check unet
# everything except down,mid,up blocks
# everything except down,mid,up blocks
modules_from_unet = [
model.base_time_embedding,
model.base_conv_in,
@@ -237,7 +236,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
assert_frozen(u.upsamplers)
# # check controlnet
# everything except down,mid,up blocks
# everything except down,mid,up blocks
modules_from_controlnet = [
model.controlnet_cond_embedding,
model.ctrl_conv_in,
@@ -268,6 +267,16 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
for u in model.up_blocks:
assert_unfrozen(u.ctrl_to_base)
def test_gradient_checkpointing_is_applied(self):
    """Delegate to the parent test, passing the expected set of class names for this model."""
    super().test_gradient_checkpointing_is_applied(
        expected_set={
            "Transformer2DModel",
            "UNetMidBlock2DCrossAttn",
            "ControlNetXSCrossAttnDownBlock2D",
            "ControlNetXSCrossAttnMidBlock2D",
            "ControlNetXSCrossAttnUpBlock2D",
        }
    )
@is_flaky
def test_forward_no_control(self):
unet = self.get_dummy_unet()
@@ -278,7 +287,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
unet = unet.to(torch_device)
model = model.to(torch_device)
input_ = self.get_dummy_inputs()
input_ = self.dummy_input
control_specific_input = ["controlnet_cond", "conditioning_scale"]
input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input}
@@ -303,7 +312,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
model = model.to(torch_device)
model_mix_time = model_mix_time.to(torch_device)
input_ = self.get_dummy_inputs()
input_ = self.dummy_input
with torch.no_grad():
output = model(**input_).sample
@@ -311,14 +320,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
assert output.shape == output_mix_time.shape
class TestUNetControlNetXSTraining(UNetControlNetXSTesterConfig, TrainingTesterMixin):
    """Runs ``TrainingTesterMixin`` with the UNetControlNetXS tester configuration."""

    def test_gradient_checkpointing_is_applied(self):
        """Delegate to the parent test, passing the expected set of class names for this model."""
        super().test_gradient_checkpointing_is_applied(
            expected_set={
                "Transformer2DModel",
                "UNetMidBlock2DCrossAttn",
                "ControlNetXSCrossAttnDownBlock2D",
                "ControlNetXSCrossAttnMidBlock2D",
                "ControlNetXSCrossAttnUpBlock2D",
            }
        )
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
    """Skipped for UNetControlNetXSModel.

    The model currently only supports StableDiffusion and StableDiffusion-XL, both of which have
    ``norm_num_groups`` fixed at 32, so different values of ``norm_num_groups`` need not be tested.
    """

View File

@@ -16,10 +16,10 @@
import copy
import unittest
import pytest
import torch
from diffusers import UNetSpatioTemporalConditionModel
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from ...testing_utils import (
@@ -28,34 +28,45 @@ from ...testing_utils import (
skip_mps,
torch_device,
)
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
ModelTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
logger = logging.get_logger(__name__)
enable_full_determinism()
@skip_mps
class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNetSpatioTemporalConditionModel testing."""
class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNetSpatioTemporalConditionModel
main_input_name = "sample"
@property
def model_class(self):
return UNetSpatioTemporalConditionModel
def dummy_input(self):
batch_size = 2
num_frames = 2
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
return {
"sample": noise,
"timestep": time_step,
"encoder_hidden_states": encoder_hidden_states,
"added_time_ids": self._get_add_time_ids(),
}
@property
def input_shape(self):
return (2, 2, 4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
@property
def main_input_name(self):
return "sample"
@property
def fps(self):
return 6
@@ -72,8 +83,8 @@ class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
def addition_time_embed_dim(self):
return 32
def get_init_dict(self):
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (32, 64),
"down_block_types": (
"CrossAttnDownBlockSpatioTemporal",
@@ -92,23 +103,8 @@ class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
"projection_class_embeddings_input_dim": self.addition_time_embed_dim * 3,
"addition_time_embed_dim": self.addition_time_embed_dim,
}
def get_dummy_inputs(self):
batch_size = 2
num_frames = 2
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
return {
"sample": noise,
"timestep": time_step,
"encoder_hidden_states": encoder_hidden_states,
"added_time_ids": self._get_add_time_ids(),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def _get_add_time_ids(self, do_classifier_free_guidance=True):
add_time_ids = [self.fps, self.motion_bucket_id, self.noise_aug_strength]
@@ -128,15 +124,43 @@ class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
return add_time_ids
class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Number of Norm Groups is not configurable")
@unittest.skip("Number of Norm Groups is not configurable")
def test_forward_with_norm_groups(self):
pass
@unittest.skip("Deprecated functionality")
def test_model_attention_slicing(self):
    """Always skipped via the decorator; the body is intentionally empty."""
@unittest.skip("Not supported")
def test_model_with_use_linear_projection(self):
    """Always skipped via the decorator; the body is intentionally empty."""
@unittest.skip("Not supported")
def test_model_with_simple_projection(self):
    """Always skipped via the decorator; the body is intentionally empty."""
@unittest.skip("Not supported")
def test_model_with_class_embeddings_concat(self):
    """Always skipped via the decorator; the body is intentionally empty."""
@unittest.skipIf(
    not (torch_device == "cuda" and is_xformers_available()),
    reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
    """After enabling xformers attention, the first mid-block ``attn1`` must use ``XFormersAttnProcessor``."""
    config, _inputs = self.prepare_init_args_and_inputs_for_common()
    unet = self.model_class(**config)
    unet.enable_xformers_memory_efficient_attention()

    first_attn = unet.mid_block.attentions[0].transformer_blocks[0].attn1
    processor_name = first_attn.processor.__class__.__name__
    assert processor_name == "XFormersAttnProcessor", "xformers is not enabled"
def test_model_with_num_attention_heads_tuple(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["num_attention_heads"] = (8, 16)
model = self.model_class(**init_dict)
@@ -149,13 +173,12 @@ class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, U
if isinstance(output, dict):
output = output.sample
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_cross_attention_dim_tuple(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["cross_attention_dim"] = (32, 32)
@@ -169,13 +192,27 @@ class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, U
if isinstance(output, dict):
output = output.sample
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_gradient_checkpointing_is_applied(self):
    """Delegate to the parent test with this model's expected class names and ``num_attention_heads``."""
    super().test_gradient_checkpointing_is_applied(
        expected_set={
            "TransformerSpatioTemporalModel",
            "CrossAttnDownBlockSpatioTemporal",
            "DownBlockSpatioTemporal",
            "UpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
            "UNetMidBlockSpatioTemporal",
        },
        num_attention_heads=(8, 16),
    )
def test_pickle(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["num_attention_heads"] = (8, 16)
@@ -188,33 +225,3 @@ class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, U
sample_copy = copy.copy(sample)
assert (sample - sample_copy).abs().max() < 1e-4
class TestUNetSpatioTemporalAttention(UNetSpatioTemporalTesterConfig, AttentionTesterMixin):
    """Attention tests built from UNetSpatioTemporalTesterConfig and AttentionTesterMixin."""

    @unittest.skipIf(
        not (torch_device == "cuda" and is_xformers_available()),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_enable_works(self):
        """Enable xFormers attention and check the resulting processor class.

        Builds the model from ``get_init_dict()``, switches on xFormers
        memory-efficient attention, and asserts that the first mid-block
        transformer block's ``attn1`` processor is an ``XFormersAttnProcessor``.
        """
        unet = self.model_class(**self.get_init_dict())
        unet.enable_xformers_memory_efficient_attention()

        # Only one attention layer is inspected as a representative sample.
        first_mid_attn = unet.mid_block.attentions[0].transformer_blocks[0].attn1
        processor_name = first_mid_attn.processor.__class__.__name__
        assert processor_name == "XFormersAttnProcessor", "xformers is not enabled"
class TestUNetSpatioTemporalTraining(UNetSpatioTemporalTesterConfig, TrainingTesterMixin):
    """Training tests built from UNetSpatioTemporalTesterConfig and TrainingTesterMixin."""

    def test_gradient_checkpointing_is_applied(self):
        """Run the inherited gradient-checkpointing test with this model's block names."""
        # NOTE(review): presumably the classes the parent test expects to find
        # with gradient checkpointing enabled -- confirm against the mixin.
        checkpointed_blocks = {
            "CrossAttnDownBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
            "DownBlockSpatioTemporal",
            "UpBlockSpatioTemporal",
            "UNetMidBlockSpatioTemporal",
            "TransformerSpatioTemporalModel",
        }
        super().test_gradient_checkpointing_is_applied(expected_set=checkpointed_blocks)

View File

@@ -1,7 +1,6 @@
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer
from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig
@@ -11,17 +10,11 @@ from diffusers.models import AutoencoderDC, AutoencoderKL
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.prx.pipeline_prx import PRXPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_transformers_version
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@pytest.mark.xfail(
condition=is_transformers_version(">", "4.57.1"),
reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
strict=False,
)
class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = PRXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}