Compare commits

..

7 Commits

Author SHA1 Message Date
sayakpaul
ad9ac8dba6 generate. 2026-02-24 17:14:45 +05:30
sayakpaul
acfa871347 fix ring for flash and flash_3 2026-02-24 17:07:00 +05:30
sayakpaul
547f3df0a0 up 2026-02-24 15:48:49 +05:30
sayakpaul
1d12bd215f up 2026-02-24 15:30:11 +05:30
sayakpaul
e7317067ab up 2026-02-24 15:25:35 +05:30
sayakpaul
3f36c6d4f4 tests: add cp backend and attention backend tests. 2026-02-24 15:13:08 +05:30
SYM.BOT
1f6ac1c3d1 fix: graceful fallback when attention backends fail to import (#13060)
* fix: graceful fallback when attention backends fail to import

## Problem

External attention backends (flash_attn, xformers, sageattention, etc.) may be
installed but fail to import at runtime due to ABI mismatches. For example,
when `flash_attn` is compiled against PyTorch 2.4 but used with PyTorch 2.8,
the import fails with:

```
OSError: .../flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEab
```

The current code uses `importlib.util.find_spec()` to check if packages exist,
but this only verifies the package is installed—not that it can actually be
imported. When the import fails, diffusers crashes instead of falling back to
native PyTorch attention.

## Solution

Wrap all external attention backend imports in try-except blocks that catch
`ImportError` and `OSError`. On failure:
1. Log a warning message explaining the issue
2. Set the corresponding `_CAN_USE_*` flag to `False`
3. Set the imported functions to `None`

This allows diffusers to gracefully degrade to PyTorch's native SDPA
(scaled_dot_product_attention) instead of crashing.

## Affected backends

- flash_attn (Flash Attention)
- flash_attn_3 (Flash Attention 3)
- aiter (AMD Instinct)
- sageattention (SageAttention)
- flex_attention (PyTorch Flex Attention)
- torch_npu (Huawei NPU)
- torch_xla (TPU/XLA)
- xformers (Meta xFormers)

## Testing

Tested with PyTorch 2.8.0 and flash_attn 2.7.4.post1 (compiled for PyTorch 2.4).
Before: crashes on import. After: logs warning and uses native attention.

* address review: use single logger and catch RuntimeError

- Move logger to module level instead of creating per-backend loggers
- Add RuntimeError to exception list alongside ImportError and OSError

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Apply style fixes

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-02-24 13:37:39 +05:30
8 changed files with 190 additions and 115 deletions

View File

@@ -60,16 +60,6 @@ class ContextParallelConfig:
rotate_method (`str`, *optional*, defaults to `"allgather"`):
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
is supported.
ulysses_anything (`bool`, *optional*, defaults to `False`):
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
`ring_degree` must be 1.
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
creating a new one. This is useful when combining context parallelism with other parallelism strategies
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
"""
@@ -78,7 +68,6 @@ class ContextParallelConfig:
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
mesh: torch.distributed.device_mesh.DeviceMesh | None = None
# Whether to enable ulysses anything attention to support
# any sequence lengths and any head numbers.
ulysses_anything: bool = False
@@ -135,7 +124,7 @@ class ContextParallelConfig:
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
self._flattened_mesh = self._mesh._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()

View File

@@ -62,6 +62,8 @@ _REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
_REQUIRED_XFORMERS_VERSION = "0.0.29"
logger = get_logger(__name__) # pylint: disable=invalid-name
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
@@ -73,8 +75,18 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
if _CAN_USE_FLASH_ATTN:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
except (ImportError, OSError, RuntimeError) as e:
# Handle ABI mismatch or other import failures gracefully.
# This can happen when flash_attn was compiled against a different PyTorch version.
logger.warning(f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention.")
_CAN_USE_FLASH_ATTN = False
flash_attn_func = None
flash_attn_varlen_func = None
_wrapped_flash_attn_backward = None
_wrapped_flash_attn_forward = None
else:
flash_attn_func = None
flash_attn_varlen_func = None
@@ -83,26 +95,47 @@ else:
if _CAN_USE_FLASH_ATTN_3:
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
try:
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
_CAN_USE_FLASH_ATTN_3 = False
flash_attn_3_func = None
flash_attn_3_varlen_func = None
else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
if _CAN_USE_AITER_ATTN:
from aiter import flash_attn_func as aiter_flash_attn_func
try:
from aiter import flash_attn_func as aiter_flash_attn_func
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
_CAN_USE_AITER_ATTN = False
aiter_flash_attn_func = None
else:
aiter_flash_attn_func = None
if _CAN_USE_SAGE_ATTN:
from sageattention import (
sageattn,
sageattn_qk_int8_pv_fp8_cuda,
sageattn_qk_int8_pv_fp8_cuda_sm90,
sageattn_qk_int8_pv_fp16_cuda,
sageattn_qk_int8_pv_fp16_triton,
sageattn_varlen,
)
try:
from sageattention import (
sageattn,
sageattn_qk_int8_pv_fp8_cuda,
sageattn_qk_int8_pv_fp8_cuda_sm90,
sageattn_qk_int8_pv_fp16_cuda,
sageattn_qk_int8_pv_fp16_triton,
sageattn_varlen,
)
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
_CAN_USE_SAGE_ATTN = False
sageattn = None
sageattn_qk_int8_pv_fp8_cuda = None
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
sageattn_qk_int8_pv_fp16_cuda = None
sageattn_qk_int8_pv_fp16_triton = None
sageattn_varlen = None
else:
sageattn = None
sageattn_qk_int8_pv_fp16_cuda = None
@@ -113,26 +146,48 @@ else:
if _CAN_USE_FLEX_ATTN:
# We cannot import the flex_attention function from the package directly because it is expected (from the
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
# compiled function.
import torch.nn.attention.flex_attention as flex_attention
try:
# We cannot import the flex_attention function from the package directly because it is expected (from the
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
# compiled function.
import torch.nn.attention.flex_attention as flex_attention
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
_CAN_USE_FLEX_ATTN = False
flex_attention = None
else:
flex_attention = None
if _CAN_USE_NPU_ATTN:
from torch_npu import npu_fusion_attention
try:
from torch_npu import npu_fusion_attention
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
_CAN_USE_NPU_ATTN = False
npu_fusion_attention = None
else:
npu_fusion_attention = None
if _CAN_USE_XLA_ATTN:
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
try:
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
_CAN_USE_XLA_ATTN = False
xla_flash_attention = None
else:
xla_flash_attention = None
if _CAN_USE_XFORMERS_ATTN:
import xformers.ops as xops
try:
import xformers.ops as xops
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
_CAN_USE_XFORMERS_ATTN = False
xops = None
else:
xops = None
@@ -158,8 +213,6 @@ else:
_register_fake = register_fake_no_op
logger = get_logger(__name__) # pylint: disable=invalid-name
# TODO(aryan): Add support for the following:
# - Sage Attention++
# - block sparse, radial and other attention methods
@@ -1812,9 +1865,12 @@ class TemplatedRingAttention(torch.autograd.Function):
out = out.to(torch.float32)
lse = lse.to(torch.float32)
# Refer to:
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if is_torch_version("<", "2.9.0"):
# lse must be 4-D to broadcast with out (B, S, H, D).
# Some backends (e.g. cuDNN on torch>=2.9) already return a
# trailing-1 dim; others (e.g. flash-hub / native-flash) always
# return 3-D lse, so we add the dim here when needed.
# See: https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if lse.ndim == 3:
lse = lse.unsqueeze(-1)
if prev_out is not None:
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
@@ -2101,10 +2157,11 @@ def _templated_unified_attention(
scatter_idx,
)
if return_lse:
# lse is of shape (B, S, H_LOCAL, 1)
# Refer to:
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if is_torch_version("<", "2.9.0"):
# lse from TemplatedRingAttention is 3-D (B, S, H_LOCAL) after its
# final squeeze(-1). SeqAllToAllDim requires a 4-D input, so we add
# the trailing dim here and remove it after the collective.
# See: https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if lse.ndim == 3:
lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1)
lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
lse = lse.squeeze(-1)

View File

@@ -1567,7 +1567,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
mesh = None
if config.context_parallel_config is not None:
cp_config = config.context_parallel_config
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
mesh = torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=cp_config.mesh_shape,
mesh_dim_names=cp_config.mesh_dim_names,

View File

@@ -13,7 +13,7 @@ from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .parallelism import ContextParallelTesterMixin
from .parallelism import ContextParallelAttentionBackendsTesterMixin, ContextParallelTesterMixin
from .quantization import (
BitsAndBytesCompileTesterMixin,
BitsAndBytesConfigMixin,
@@ -45,6 +45,7 @@ __all__ = [
"BitsAndBytesTesterMixin",
"CacheTesterMixin",
"ContextParallelTesterMixin",
"ContextParallelAttentionBackendsTesterMixin",
"CPUOffloadTesterMixin",
"FasterCacheConfigMixin",
"FasterCacheTesterMixin",

View File

@@ -23,10 +23,8 @@ import torch.multiprocessing as mp
from diffusers.models._modeling_parallel import ContextParallelConfig
from ...testing_utils import (
is_context_parallel,
require_torch_multi_accelerator,
)
from ...testing_utils import is_context_parallel, is_kernels_available, require_torch_multi_accelerator
from .utils import _maybe_cast_to_bf16
def _find_free_port():
@@ -38,7 +36,9 @@ def _find_free_port():
return port
def _context_parallel_worker(rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict):
def _context_parallel_worker(
rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict, attention_backend=None
):
"""Worker function for context parallel testing."""
try:
# Set up distributed environment
@@ -59,8 +59,23 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
model.to(device)
model.eval()
# Cast as needed.
model, inputs_dict = _maybe_cast_to_bf16(attention_backend, model, inputs_dict)
# Move inputs to device
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
inputs_on_device = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
inputs_on_device[key] = value.to(device)
else:
inputs_on_device[key] = value
# Enable attention backend
if attention_backend:
try:
model.set_attention_backend(attention_backend)
except Exception as e:
pytest.skip(f"Skipping test because of exception: {e}.")
# Enable context parallelism
cp_config = ContextParallelConfig(**cp_dict)
@@ -84,59 +99,6 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
dist.destroy_process_group()
def _custom_mesh_worker(
rank,
world_size,
master_port,
model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
):
"""Worker function for context parallel testing with a user-provided custom DeviceMesh."""
try:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
model = model_class(**init_dict)
model.to(device)
model.eval()
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
# DeviceMesh must be created after init_process_group, inside each worker process.
mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
)
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
model.enable_parallelism(config=cp_config)
with torch.no_grad():
output = model(**inputs_on_device, return_dict=False)[0]
if rank == 0:
return_dict["status"] = "success"
return_dict["output_shape"] = list(output.shape)
except Exception as e:
if rank == 0:
return_dict["status"] = "error"
return_dict["error"] = str(e)
finally:
if dist.is_initialized():
dist.destroy_process_group()
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@@ -175,47 +137,75 @@ class ContextParallelTesterMixin:
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelAttentionBackendsTesterMixin:
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"])
@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
"attention_backend",
[
("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
"native",
pytest.param(
"flash_hub",
marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
),
pytest.param(
"_flash_3_hub",
marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
),
],
ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
)
def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
@pytest.mark.parametrize("ulysses_anything", [True, False])
@torch.no_grad()
def test_context_parallel_attn_backend_inference(self, cp_type, attention_backend, ulysses_anything):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
if getattr(self.model_class, "_cp_plan", None) is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
if cp_type == "ring_degree":
if attention_backend == "native":
pytest.skip("Skipping test because ulysses isn't supported with native attention backend.")
if ulysses_anything and "ulysses" not in cp_type:
pytest.skip("Skipping test as ulysses anything needs the ulysses degree set.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
cp_dict = {cp_type: world_size}
inputs_dict = self.get_dummy_inputs()
# Move all tensors to CPU for multiprocessing
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
cp_dict = {cp_type: world_size}
if ulysses_anything:
cp_dict.update({"ulysses_anything": ulysses_anything})
# Find a free port for distributed communication
master_port = _find_free_port()
# Use multiprocessing manager for cross-process communication
manager = mp.Manager()
return_dict = manager.dict()
# Spawn worker processes
mp.spawn(
_custom_mesh_worker,
_context_parallel_worker,
args=(
world_size,
master_port,
self.model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
attention_backend,
),
nprocs=world_size,
join=True,
)
assert return_dict.get("status") == "success", (
f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)

View File

@@ -0,0 +1,22 @@
import torch
from diffusers.models.attention_dispatch import AttentionBackendName
_BF16_REQUIRED_BACKENDS = {
AttentionBackendName._NATIVE_CUDNN,
AttentionBackendName.FLASH_HUB,
AttentionBackendName._FLASH_3_HUB,
}
def _maybe_cast_to_bf16(backend, model, inputs_dict):
"""Cast model and floating-point inputs to bfloat16 when the backend requires it."""
if not backend or backend not in _BF16_REQUIRED_BACKENDS:
return model, inputs_dict
model = model.to(dtype=torch.bfloat16)
inputs_dict = {
k: v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs_dict.items()
}
return model, inputs_dict

View File

@@ -29,6 +29,7 @@ from ..testing_utils import (
BaseModelTesterConfig,
BitsAndBytesCompileTesterMixin,
BitsAndBytesTesterMixin,
ContextParallelAttentionBackendsTesterMixin,
ContextParallelTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
@@ -228,6 +229,12 @@ class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextPar
"""Context Parallel inference tests for Flux Transformer"""
class TestFluxTransformerContextParallelAttnBackends(
FluxTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin
):
"""Context Parallel inference x attention backends tests for Flux Transformer"""
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
"""IP Adapter tests for Flux Transformer."""

View File

@@ -72,6 +72,7 @@ OPTIONAL_TESTERS = [
# Other testers
("SingleFileTesterMixin", "single_file"),
("IPAdapterTesterMixin", "ip_adapter"),
("ContextParallelAttentionBackendsTesterMixin", "cp_attn"),
]
@@ -229,7 +230,14 @@ def determine_testers(model_info: dict, include_optional: list[str], imports: se
for tester, flag in OPTIONAL_TESTERS:
if flag in include_optional:
if tester not in testers:
if tester == "ContextParallelAttentionBackendsTesterMixin":
if (
"cp_attn" in include_optional
and "_cp_plan" in model_info["attributes"]
and model_info["attributes"]["_cp_plan"] is not None
):
testers.append(tester)
elif tester not in testers:
testers.append(tester)
return testers
@@ -530,6 +538,7 @@ def main():
"faster_cache",
"single_file",
"ip_adapter",
"cp_attn",
"all",
],
help="Optional testers to include",