Compare commits

...

6 Commits

Author SHA1 Message Date
sayakpaul
ad9ac8dba6 generate. 2026-02-24 17:14:45 +05:30
sayakpaul
acfa871347 fix ring for flash and flash_3 2026-02-24 17:07:00 +05:30
sayakpaul
547f3df0a0 up 2026-02-24 15:48:49 +05:30
sayakpaul
1d12bd215f up 2026-02-24 15:30:11 +05:30
sayakpaul
e7317067ab up 2026-02-24 15:25:35 +05:30
sayakpaul
3f36c6d4f4 tests: add cp backend and attention backend tests. 2026-02-24 15:13:08 +05:30
6 changed files with 140 additions and 14 deletions

View File

@@ -1865,9 +1865,12 @@ class TemplatedRingAttention(torch.autograd.Function):
out = out.to(torch.float32)
lse = lse.to(torch.float32)
# Refer to:
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if is_torch_version("<", "2.9.0"):
# lse must be 4-D to broadcast with out (B, S, H, D).
# Some backends (e.g. cuDNN on torch>=2.9) already return a
# trailing-1 dim; others (e.g. flash-hub / native-flash) always
# return 3-D lse, so we add the dim here when needed.
# See: https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if lse.ndim == 3:
lse = lse.unsqueeze(-1)
if prev_out is not None:
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
@@ -2154,10 +2157,11 @@ def _templated_unified_attention(
scatter_idx,
)
if return_lse:
# lse is of shape (B, S, H_LOCAL, 1)
# Refer to:
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if is_torch_version("<", "2.9.0"):
# lse from TemplatedRingAttention is 3-D (B, S, H_LOCAL) after its
# final squeeze(-1). SeqAllToAllDim requires a 4-D input, so we add
# the trailing dim here and remove it after the collective.
# See: https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if lse.ndim == 3:
lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1)
lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
lse = lse.squeeze(-1)

View File

@@ -13,7 +13,7 @@ from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .parallelism import ContextParallelTesterMixin
from .parallelism import ContextParallelAttentionBackendsTesterMixin, ContextParallelTesterMixin
from .quantization import (
BitsAndBytesCompileTesterMixin,
BitsAndBytesConfigMixin,
@@ -45,6 +45,7 @@ __all__ = [
"BitsAndBytesTesterMixin",
"CacheTesterMixin",
"ContextParallelTesterMixin",
"ContextParallelAttentionBackendsTesterMixin",
"CPUOffloadTesterMixin",
"FasterCacheConfigMixin",
"FasterCacheTesterMixin",

View File

@@ -23,10 +23,8 @@ import torch.multiprocessing as mp
from diffusers.models._modeling_parallel import ContextParallelConfig
from ...testing_utils import (
is_context_parallel,
require_torch_multi_accelerator,
)
from ...testing_utils import is_context_parallel, is_kernels_available, require_torch_multi_accelerator
from .utils import _maybe_cast_to_bf16
def _find_free_port():
@@ -38,7 +36,9 @@ def _find_free_port():
return port
def _context_parallel_worker(rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict):
def _context_parallel_worker(
rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict, attention_backend=None
):
"""Worker function for context parallel testing."""
try:
# Set up distributed environment
@@ -59,6 +59,9 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
model.to(device)
model.eval()
# Cast as needed.
model, inputs_dict = _maybe_cast_to_bf16(attention_backend, model, inputs_dict)
# Move inputs to device
inputs_on_device = {}
for key, value in inputs_dict.items():
@@ -67,6 +70,13 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
else:
inputs_on_device[key] = value
# Enable attention backend
if attention_backend:
try:
model.set_attention_backend(attention_backend)
except Exception as e:
pytest.skip(f"Skipping test because of exception: {e}.")
# Enable context parallelism
cp_config = ContextParallelConfig(**cp_dict)
model.enable_parallelism(config=cp_config)
@@ -126,3 +136,76 @@ class ContextParallelTesterMixin:
assert return_dict.get("status") == "success", (
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelAttentionBackendsTesterMixin:
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"])
@pytest.mark.parametrize(
"attention_backend",
[
"native",
pytest.param(
"flash_hub",
marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
),
pytest.param(
"_flash_3_hub",
marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
),
],
)
@pytest.mark.parametrize("ulysses_anything", [True, False])
@torch.no_grad()
def test_context_parallel_attn_backend_inference(self, cp_type, attention_backend, ulysses_anything):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if getattr(self.model_class, "_cp_plan", None) is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
if cp_type == "ring_degree":
if attention_backend == "native":
pytest.skip("Skipping test because ulysses isn't supported with native attention backend.")
if ulysses_anything and "ulysses" not in cp_type:
pytest.skip("Skipping test as ulysses anything needs the ulysses degree set.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
# Move all tensors to CPU for multiprocessing
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
cp_dict = {cp_type: world_size}
if ulysses_anything:
cp_dict.update({"ulysses_anything": ulysses_anything})
# Find a free port for distributed communication
master_port = _find_free_port()
# Use multiprocessing manager for cross-process communication
manager = mp.Manager()
return_dict = manager.dict()
# Spawn worker processes
mp.spawn(
_context_parallel_worker,
args=(
world_size,
master_port,
self.model_class,
init_dict,
cp_dict,
inputs_dict,
return_dict,
attention_backend,
),
nprocs=world_size,
join=True,
)
assert return_dict.get("status") == "success", (
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)

View File

@@ -0,0 +1,22 @@
import torch
from diffusers.models.attention_dispatch import AttentionBackendName
_BF16_REQUIRED_BACKENDS = {
AttentionBackendName._NATIVE_CUDNN,
AttentionBackendName.FLASH_HUB,
AttentionBackendName._FLASH_3_HUB,
}
def _maybe_cast_to_bf16(backend, model, inputs_dict):
"""Cast model and floating-point inputs to bfloat16 when the backend requires it."""
if not backend or backend not in _BF16_REQUIRED_BACKENDS:
return model, inputs_dict
model = model.to(dtype=torch.bfloat16)
inputs_dict = {
k: v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs_dict.items()
}
return model, inputs_dict

View File

@@ -29,6 +29,7 @@ from ..testing_utils import (
BaseModelTesterConfig,
BitsAndBytesCompileTesterMixin,
BitsAndBytesTesterMixin,
ContextParallelAttentionBackendsTesterMixin,
ContextParallelTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
@@ -228,6 +229,12 @@ class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextPar
"""Context Parallel inference tests for Flux Transformer"""
class TestFluxTransformerContextParallelAttnBackends(
FluxTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin
):
"""Context Parallel inference x attention backends tests for Flux Transformer"""
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
"""IP Adapter tests for Flux Transformer."""

View File

@@ -72,6 +72,7 @@ OPTIONAL_TESTERS = [
# Other testers
("SingleFileTesterMixin", "single_file"),
("IPAdapterTesterMixin", "ip_adapter"),
("ContextParallelAttentionBackendsTesterMixin", "cp_attn"),
]
@@ -229,7 +230,14 @@ def determine_testers(model_info: dict, include_optional: list[str], imports: se
for tester, flag in OPTIONAL_TESTERS:
if flag in include_optional:
if tester not in testers:
if tester == "ContextParallelAttentionBackendsTesterMixin":
if (
"cp_attn" in include_optional
and "_cp_plan" in model_info["attributes"]
and model_info["attributes"]["_cp_plan"] is not None
):
testers.append(tester)
elif tester not in testers:
testers.append(tester)
return testers
@@ -530,6 +538,7 @@ def main():
"faster_cache",
"single_file",
"ip_adapter",
"cp_attn",
"all",
],
help="Optional testers to include",