Compare commits

...

6 Commits

Author SHA1 Message Date
DN6
76062a74e0 update 2026-03-23 17:16:44 +05:30
Dhruv Nair
52558b45d8 [CI] Flux2 Model Test Refactor (#13071)
* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-23 16:56:08 +05:30
Sayak Paul
c02c17c6ee [tests] test load_components in modular (#13245)
* test load_components.

* fix

* fix

* u[

* up
2026-03-21 09:41:48 +05:30
Sayak Paul
a9855c4204 [tests] fix audioldm2 tests. (#13293)
fix audioldm2 tests.
2026-03-20 20:53:21 +05:30
Sayak Paul
0b35834351 [core] fa4 support. (#13280)
* start fa4 support.

* up

* specify minimum version
2026-03-20 17:28:09 +05:30
Sayak Paul
522b523e40 [ci] hoping to fix is_flaky with wanvace. (#13294)
* hoping to fix is_flaky with wanvace.

* revert changes in src/diffusers/utils/testing_utils.py and propagate them to tests/testing_utils.py.

* up
2026-03-20 16:02:16 +05:30
9 changed files with 822 additions and 172 deletions

View File

@@ -143,6 +143,7 @@ Refer to the table below for a complete list of available attention backends and
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
| `flash_4_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-4 |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
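For reference (not part of this diff): once merged, the new backend is selected like any other entry in this table, for example through the `attention_backend` context manager that this PR itself touches. A minimal sketch, assuming a CUDA machine with `kernels>=0.12.3` installed; the checkpoint and prompt are placeholders:

import torch
from diffusers import FluxPipeline, attention_backend

# Placeholder pipeline; any model whose attention goes through the dispatcher works.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Route attention through the FlashAttention-4 kernel pulled from the Hub.
with attention_backend("flash_4_hub"):
    image = pipe("a photo of a cat", num_inference_steps=4).images[0]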

View File

@@ -229,6 +229,7 @@ class AttentionBackendName(str, Enum):
FLASH_HUB = "flash_hub"
FLASH_VARLEN = "flash_varlen"
FLASH_VARLEN_HUB = "flash_varlen_hub"
FLASH_4_HUB = "flash_4_hub"
_FLASH_3 = "_flash_3"
_FLASH_VARLEN_3 = "_flash_varlen_3"
_FLASH_3_HUB = "_flash_3_hub"
@@ -358,6 +359,11 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
function_attr="sageattn",
version=1,
),
AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
repo_id="kernels-staging/flash-attn4",
function_attr="flash_attn_func",
version=0,
),
}
@@ -521,6 +527,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
AttentionBackendName._FLASH_3_HUB,
AttentionBackendName._FLASH_3_VARLEN_HUB,
AttentionBackendName.SAGE_HUB,
AttentionBackendName.FLASH_4_HUB,
]:
if not is_kernels_available():
raise RuntimeError(
@@ -531,6 +538,11 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
)
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
raise RuntimeError(
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
)
elif backend == AttentionBackendName.AITER:
if not _CAN_USE_AITER_ATTN:
raise RuntimeError(
@@ -2676,6 +2688,37 @@ def _flash_attention_3_varlen_hub(
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_4_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
)
def _flash_attention_4_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
scale: float | None = None,
is_causal: bool = False,
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 4.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
)
if isinstance(out, tuple):
return (out[0], out[1]) if return_lse else out[0]
return out
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_VARLEN_3,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
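As a rough illustration (not part of the diff) of what the FLASH_4_HUB registry entry implies: the repo_id/function_attr pair resolves to a callable fetched via the `kernels` library. A hedged sketch assuming that library's `get_kernel` API; the actual loading path inside diffusers may add version pinning and caching on top:

from kernels import get_kernel

# repo_id and function_attr mirror the _HubKernelConfig registered above.
flash4_kernel = get_kernel("kernels-staging/flash-attn4")
flash_attn_func = getattr(flash4_kernel, "flash_attn_func")

# _flash_attention_4_hub then calls it with q/k/v plus softmax_scale and causal;
# attn_mask is rejected up front because flash-attn 4 does not support it.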

View File

@@ -324,17 +324,18 @@ class AudioLDM2Pipeline(DiffusionPipeline):
`inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
The sequence of generated hidden-states.
"""
cache_position_kwargs = {}
if is_transformers_version("<", "4.52.1"):
cache_position_kwargs["input_ids"] = inputs_embeds
else:
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
cache_position_kwargs["device"] = (
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
)
cache_position_kwargs["model_kwargs"] = model_kwargs
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
if hasattr(self.language_model, "_get_initial_cache_position"):
cache_position_kwargs = {}
if is_transformers_version("<", "4.52.1"):
cache_position_kwargs["input_ids"] = inputs_embeds
else:
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
cache_position_kwargs["device"] = (
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
)
cache_position_kwargs["model_kwargs"] = model_kwargs
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
for _ in range(max_new_tokens):
# prepare model inputs

View File

@@ -28,7 +28,6 @@ from diffusers.utils.import_utils import is_peft_available
from ..testing_utils import (
floats_tensor,
is_flaky,
require_peft_backend,
require_peft_version_greater,
skip_mps,
@@ -46,7 +45,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
@is_flaky(max_attempts=10, description="very flaky class")
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanVACEPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
@@ -73,8 +71,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"base_dim": 3,
"z_dim": 4,
"dim_mult": [1, 1, 1, 1],
"latents_mean": torch.randn(4).numpy().tolist(),
"latents_std": torch.randn(4).numpy().tolist(),
"latents_mean": [-0.7571, -0.7089, -0.9113, -0.7245],
"latents_std": [2.8184, 1.4541, 2.3275, 2.6558],
"num_res_blocks": 1,
"temperal_downsample": [False, True, True],
}
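The fixed statistics are the point of this change: the old values were drawn from an unseeded `torch.randn` at class-definition time, so the VAE normalization differed on every run, which is one plausible source of the flakiness this stack targets (hence the removal of the `@is_flaky` decorator above). A minimal sketch of the nondeterminism:

import torch

# Unseeded class-body randn yields different normalization stats per run/import.
print(torch.randn(4).tolist())  # e.g. [0.33, -1.02, 0.47, -0.88]
print(torch.randn(4).tolist())  # different values on the next call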

View File

@@ -41,7 +41,6 @@ from ..testing_utils import (
ModelOptCompileTesterMixin,
ModelOptTesterMixin,
ModelTesterMixin,
PyramidAttentionBroadcastTesterMixin,
QuantoCompileTesterMixin,
QuantoTesterMixin,
SingleFileTesterMixin,
@@ -219,6 +218,10 @@ class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux Transformer."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux Transformer."""
@@ -412,10 +415,6 @@ class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAn
"""BitsAndBytes + compile tests for Flux Transformer."""
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
"""FirstBlockCache tests for Flux Transformer."""

View File

@@ -13,48 +13,95 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import Flux2Transformer2DModel, attention_backend
from diffusers import Flux2Transformer2DModel
from diffusers.models.transformers.transformer_flux2 import (
Flux2KVAttnProcessor,
Flux2KVCache,
Flux2KVLayerCache,
Flux2KVParallelSelfAttnProcessor,
)
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoCompileTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = Flux2Transformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
class Flux2TransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return Flux2Transformer2DModel
@property
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
def output_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def output_shape(self):
def input_shape(self) -> tuple[int, int]:
return (16, 4)
def prepare_dummy_input(self, height=4, width=4):
@property
def model_split_percents(self) -> list:
# We override the items here because the transformer under consideration is small.
return [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def uses_custom_attn_processor(self) -> bool:
# Skip setting testing with default: AttnProcessor
return True
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
return {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"timestep_guidance_channels": 256, # Hardcoded in original code
"axes_dims_rope": [4, 4, 4, 4],
}
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
@@ -82,8 +129,286 @@ class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
"guidance": guidance,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
class TestFlux2Transformer(Flux2TransformerTesterConfig, ModelTesterMixin):
pass
class TestFlux2TransformerMemory(Flux2TransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Flux2 Transformer."""
class TestFlux2TransformerTraining(Flux2TransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux2 Transformer."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Flux2Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestFlux2TransformerAttention(Flux2TransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux2 Transformer."""
class TestFlux2TransformerContextParallel(Flux2TransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for Flux2 Transformer."""
class TestFlux2TransformerLoRA(Flux2TransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for Flux2 Transformer."""
class TestFlux2TransformerLoRAHotSwap(Flux2TransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for Flux2 Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for LoRA hotswap tests."""
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class TestFlux2TransformerCompile(Flux2TransformerTesterConfig, TorchCompileTesterMixin):
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for compilation tests."""
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class TestFlux2TransformerBitsAndBytes(Flux2TransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Flux2 Transformer."""
class TestFlux2TransformerTorchAo(Flux2TransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Flux2 Transformer."""
class TestFlux2TransformerGGUF(Flux2TransformerTesterConfig, GGUFTesterMixin):
"""GGUF quantization tests for Flux2 Transformer."""
@property
def gguf_filename(self):
return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real FLUX2 model dimensions.
Flux2 defaults: in_channels=128, joint_attention_dim=15360
"""
batch_size = 1
height = 64
width = 64
sequence_length = 512
hidden_states = randn_tensor(
(batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
# Flux2 uses 4D image/text IDs (t, h, w, l)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class TestFlux2TransformerTorchAoCompile(Flux2TransformerTesterConfig, TorchAoCompileTesterMixin):
"""TorchAO + compile tests for Flux2 Transformer."""
class TestFlux2TransformerGGUFCompile(Flux2TransformerTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Flux2 Transformer."""
@property
def gguf_filename(self):
return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real FLUX2 model dimensions.
Flux2 defaults: in_channels=128, joint_attention_dim=15360
"""
batch_size = 1
height = 64
width = 64
sequence_length = 512
hidden_states = randn_tensor(
(batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
# Flux2 uses 4D image/text IDs (t, h, w, l)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class Flux2TransformerKVCacheTesterConfig(BaseModelTesterConfig):
num_ref_tokens = 4
@property
def model_class(self):
return Flux2Transformer2DModel
@property
def output_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def input_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def model_split_percents(self) -> list:
return [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def uses_custom_attn_processor(self) -> bool:
return True
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
return {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
@@ -91,72 +416,210 @@ class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"timestep_guidance_channels": 256, # Hardcoded in original code
"timestep_guidance_channels": 256,
"axes_dims_rope": [4, 4, 4, 4],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
num_ref_tokens = self.num_ref_tokens
# TODO (Daniel, Sayak): We can remove this test.
def test_flux2_consistency(self, seed=0):
torch.manual_seed(seed)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ref_hidden_states = randn_tensor(
(batch_size, num_ref_tokens, num_latent_channels), generator=self.generator, device=torch_device
)
img_hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
hidden_states = torch.cat([ref_hidden_states, img_hidden_states], dim=1)
torch.manual_seed(seed)
model = self.model_class(**init_dict)
# state_dict = model.state_dict()
# for key, param in state_dict.items():
# print(f"{key} | {param.shape}")
# torch.save(state_dict, "/raid/daniel_gu/test_flux2_params/diffusers.pt")
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
ref_t_coords = torch.arange(1)
ref_h_coords = torch.arange(num_ref_tokens)
ref_w_coords = torch.arange(1)
ref_l_coords = torch.arange(1)
ref_ids = torch.cartesian_prod(ref_t_coords, ref_h_coords, ref_w_coords, ref_l_coords)
ref_ids = ref_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
image_ids = torch.cat([ref_ids, image_ids], dim=1)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class TestFlux2TransformerKVCache(Flux2TransformerKVCacheTesterConfig):
"""KV cache tests for Flux2 Transformer."""
def test_kv_layer_cache_store_and_get(self):
cache = Flux2KVLayerCache()
k = torch.randn(1, 4, 2, 16)
v = torch.randn(1, 4, 2, 16)
cache.store(k, v)
k_out, v_out = cache.get()
assert torch.equal(k, k_out)
assert torch.equal(v, v_out)
def test_kv_layer_cache_get_before_store_raises(self):
cache = Flux2KVLayerCache()
try:
cache.get()
assert False, "Expected RuntimeError"
except RuntimeError:
pass
def test_kv_layer_cache_clear(self):
cache = Flux2KVLayerCache()
cache.store(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16))
cache.clear()
assert cache.k_ref is None
assert cache.v_ref is None
def test_kv_cache_structure(self):
num_double = 3
num_single = 2
cache = Flux2KVCache(num_double, num_single)
assert len(cache.double_block_caches) == num_double
assert len(cache.single_block_caches) == num_single
assert cache.num_ref_tokens == 0
for i in range(num_double):
assert isinstance(cache.get_double(i), Flux2KVLayerCache)
for i in range(num_single):
assert isinstance(cache.get_single(i), Flux2KVLayerCache)
def test_kv_cache_clear(self):
cache = Flux2KVCache(2, 1)
cache.num_ref_tokens = 4
cache.get_double(0).store(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16))
cache.clear()
assert cache.num_ref_tokens == 0
assert cache.get_double(0).k_ref is None
def _set_kv_attn_processors(self, model):
for block in model.transformer_blocks:
block.attn.set_processor(Flux2KVAttnProcessor())
for block in model.single_transformer_blocks:
block.attn.set_processor(Flux2KVParallelSelfAttnProcessor())
@torch.no_grad()
def test_extract_mode_returns_cache(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
self._set_kv_attn_processors(model)
output = model(
**self.get_dummy_inputs(),
kv_cache_mode="extract",
num_ref_tokens=self.num_ref_tokens,
ref_fixed_timestep=0.0,
)
assert output.kv_cache is not None
assert isinstance(output.kv_cache, Flux2KVCache)
assert output.kv_cache.num_ref_tokens == self.num_ref_tokens
for layer_cache in output.kv_cache.double_block_caches:
assert layer_cache.k_ref is not None
assert layer_cache.v_ref is not None
for layer_cache in output.kv_cache.single_block_caches:
assert layer_cache.k_ref is not None
assert layer_cache.v_ref is not None
@torch.no_grad()
def test_extract_mode_output_shape(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
with attention_backend("native"):
with torch.no_grad():
output = model(**inputs_dict)
height, width = 4, 4
output = model(
**self.get_dummy_inputs(height=height, width=width),
kv_cache_mode="extract",
num_ref_tokens=self.num_ref_tokens,
ref_fixed_timestep=0.0,
)
if isinstance(output, dict):
output = output.to_tuple()[0]
assert output.sample.shape == (1, height * width, 4)
self.assertIsNotNone(output)
@torch.no_grad()
def test_cached_mode_uses_cache(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
# input & output have to have the same shape
input_tensor = inputs_dict[self.main_input_name]
expected_shape = input_tensor.shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
height, width = 4, 4
extract_output = model(
**self.get_dummy_inputs(height=height, width=width),
kv_cache_mode="extract",
num_ref_tokens=self.num_ref_tokens,
ref_fixed_timestep=0.0,
)
# Check against expected slice
# fmt: off
expected_slice = torch.tensor([-0.3662, 0.4844, 0.6334, -0.3497, 0.2162, 0.0188, 0.0521, -0.2061, -0.2041, -0.0342, -0.7107, 0.4797, -0.3280, 0.7059, -0.0849, 0.4416])
# fmt: on
base_config = Flux2TransformerTesterConfig()
cached_inputs = base_config.get_dummy_inputs(height=height, width=width)
cached_output = model(
**cached_inputs,
kv_cache=extract_output.kv_cache,
kv_cache_mode="cached",
)
flat_output = output.cpu().flatten()
generated_slice = torch.cat([flat_output[:8], flat_output[-8:]])
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-4))
assert cached_output.sample.shape == (1, height * width, 4)
assert cached_output.kv_cache is None
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Flux2Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@torch.no_grad()
def test_extract_return_dict_false(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
output = model(
**self.get_dummy_inputs(),
kv_cache_mode="extract",
num_ref_tokens=self.num_ref_tokens,
ref_fixed_timestep=0.0,
return_dict=False,
)
class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = Flux2Transformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
assert isinstance(output, tuple)
assert len(output) == 2
assert isinstance(output[1], Flux2KVCache)
def prepare_init_args_and_inputs_for_common(self):
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
@torch.no_grad()
def test_no_kv_cache_mode_returns_no_cache(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
def prepare_dummy_input(self, height, width):
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
base_config = Flux2TransformerTesterConfig()
output = model(**base_config.get_dummy_inputs())
class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = Flux2Transformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def prepare_init_args_and_inputs_for_common(self):
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
def prepare_dummy_input(self, height, width):
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
assert output.kv_cache is None
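Taken together, the tests above exercise a two-phase flow: an "extract" pass that prepends reference tokens and stores their K/V, then a "cached" pass over image tokens only that attends to the stored reference K/V. A condensed sketch using the config classes from this file (mirroring test_cached_mode_uses_cache, which instantiates the configs directly):

import torch

config = Flux2TransformerKVCacheTesterConfig()
model = Flux2Transformer2DModel(**config.get_init_dict()).to(torch_device).eval()

with torch.no_grad():
    # Phase 1: reference tokens are included in the inputs; their K/V get cached.
    extract_out = model(
        **config.get_dummy_inputs(),
        kv_cache_mode="extract",
        num_ref_tokens=config.num_ref_tokens,
        ref_fixed_timestep=0.0,
    )
    # Phase 2: image-only inputs reuse the cached reference K/V.
    cached_out = model(
        **Flux2TransformerTesterConfig().get_dummy_inputs(),
        kv_cache=extract_out.kv_cache,
        kv_cache_mode="cached",
    )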

View File

@@ -1,3 +1,4 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,57 +13,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import SanaTransformer2DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SanaTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.7, 0.7, 0.9]
class SanaTransformer2DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return SanaTransformer2DModel
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = 32
width = 32
embedding_dim = 8
sequence_length = 8
def output_shape(self) -> tuple[int, ...]:
return (4, 32, 32)
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def input_shape(self) -> tuple[int, ...]:
return (4, 32, 32)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def uses_custom_attn_processor(self) -> bool:
return True
@property
def model_split_percents(self) -> list:
return [0.7, 0.7, 0.9]
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
"in_channels": 4,
"out_channels": 4,
@@ -75,9 +77,53 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
"caption_channels": 8,
"sample_size": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 2
num_channels = 4
height = 32
width = 32
embedding_dim = 8
sequence_length = 8
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,)).to(torch_device),
}
class TestSanaTransformer2D(SanaTransformer2DTesterConfig, ModelTesterMixin):
"""Core model tests for Sana Transformer 2D."""
class TestSanaTransformer2DMemory(SanaTransformer2DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Sana Transformer 2D."""
class TestSanaTransformer2DTraining(SanaTransformer2DTesterConfig, TrainingTesterMixin):
"""Training tests for Sana Transformer 2D."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"SanaTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestSanaTransformer2DAttention(SanaTransformer2DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Sana Transformer 2D."""
class TestSanaTransformer2DCompile(SanaTransformer2DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Sana Transformer 2D."""
class TestSanaTransformer2DBitsAndBytes(SanaTransformer2DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Sana Transformer 2D."""
class TestSanaTransformer2DTorchAo(SanaTransformer2DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Sana Transformer 2D."""

View File

@@ -1,3 +1,4 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,57 +13,54 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import SanaVideoTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class SanaVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = SanaVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class SanaVideoTransformer3DTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return SanaVideoTransformer3DModel
@property
def dummy_input(self):
batch_size = 1
num_channels = 16
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
def output_shape(self) -> tuple[int, ...]:
return (16, 2, 16, 16)
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
@property
def input_shape(self) -> tuple[int, ...]:
return (16, 2, 16, 16)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def uses_custom_attn_processor(self) -> bool:
return True
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | float | list[int] | tuple | str | bool]:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (16, 2, 16, 16)
@property
def output_shape(self):
return (16, 2, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 16,
"out_channels": 16,
"num_attention_heads": 2,
@@ -82,16 +80,56 @@ class SanaVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
"qk_norm": "rms_norm_across_heads",
"rope_max_seq_len": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 16
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,)).to(torch_device),
}
class TestSanaVideoTransformer3D(SanaVideoTransformer3DTesterConfig, ModelTesterMixin):
"""Core model tests for Sana Video Transformer 3D."""
class TestSanaVideoTransformer3DMemory(SanaVideoTransformer3DTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Sana Video Transformer 3D."""
class TestSanaVideoTransformer3DTraining(SanaVideoTransformer3DTesterConfig, TrainingTesterMixin):
"""Training tests for Sana Video Transformer 3D."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"SanaVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class SanaVideoTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = SanaVideoTransformer3DModel
class TestSanaVideoTransformer3DAttention(SanaVideoTransformer3DTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Sana Video Transformer 3D."""
def prepare_init_args_and_inputs_for_common(self):
return SanaVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class TestSanaVideoTransformer3DCompile(SanaVideoTransformer3DTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for Sana Video Transformer 3D."""
class TestSanaVideoTransformer3DBitsAndBytes(SanaVideoTransformer3DTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Sana Video Transformer 3D."""
class TestSanaVideoTransformer3DTorchAo(SanaVideoTransformer3DTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Sana Video Transformer 3D."""

View File

@@ -5,6 +5,7 @@ from typing import Callable
import pytest
import torch
from huggingface_hub import hf_hub_download
import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
@@ -32,6 +33,33 @@ from ..testing_utils import (
)
def _get_specified_components(path_or_repo_id, cache_dir=None):
if os.path.isdir(path_or_repo_id):
config_path = os.path.join(path_or_repo_id, "modular_model_index.json")
else:
try:
config_path = hf_hub_download(
repo_id=path_or_repo_id,
filename="modular_model_index.json",
local_dir=cache_dir,
)
except Exception:
return None
with open(config_path) as f:
config = json.load(f)
components = set()
for k, v in config.items():
if isinstance(v, (str, int, float, bool)):
continue
for entry in v:
if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")):
components.add(k)
break
return components
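To make the parsing above concrete, here is a hypothetical modular_model_index.json payload and what the helper would return for it (the exact on-disk layout is an assumption inferred from this loop):

# Hypothetical payload: scalar values are skipped, and a component counts as
# "specified" when any entry in its list carries "repo" or
# "pretrained_model_name_or_path".
example_config = {
    "_class_name": "ModularPipeline",  # scalar -> skipped
    "unet": ["diffusers", "UNet2DConditionModel", {"repo": "some-org/some-repo"}],
    "scheduler": ["diffusers", "EulerDiscreteScheduler", {}],  # no repo/path
}
# _get_specified_components would return {"unet"} for this payload.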
class ModularPipelineTesterMixin:
"""
It provides a set of common tests for each modular pipeline,
@@ -360,6 +388,39 @@ class ModularPipelineTesterMixin:
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_load_expected_components_from_pretrained(self, tmp_path):
pipe = self.get_pipeline()
expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path)
if not expected:
pytest.skip("Skipping test as we couldn't fetch the expected components.")
actual = {
name
for name in pipe.components
if getattr(pipe, name, None) is not None
and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null")
}
assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}"
def test_load_expected_components_from_save_pretrained(self, tmp_path):
pipe = self.get_pipeline()
save_dir = str(tmp_path / "saved-pipeline")
pipe.save_pretrained(save_dir)
expected = _get_specified_components(save_dir)
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
loaded_pipe.load_components(torch_dtype=torch.float32)
actual = {
name
for name in loaded_pipe.components
if getattr(loaded_pipe, name, None) is not None
and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null")
}
assert expected == actual, (
f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}"
)
def test_modular_index_consistency(self, tmp_path):
pipe = self.get_pipeline()
components_spec = pipe._component_specs