mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-26 21:00:41 +08:00
Compare commits
7 Commits
attn-backe
...
cp-attn-ba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ad9ac8dba6 | ||
|
|
acfa871347 | ||
|
|
547f3df0a0 | ||
|
|
1d12bd215f | ||
|
|
e7317067ab | ||
|
|
3f36c6d4f4 | ||
|
|
1f6ac1c3d1 |
@@ -62,6 +62,8 @@ _REQUIRED_FLEX_VERSION = "2.5.0"
|
||||
_REQUIRED_XLA_VERSION = "2.2"
|
||||
_REQUIRED_XFORMERS_VERSION = "0.0.29"
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
|
||||
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
|
||||
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
|
||||
@@ -73,8 +75,18 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
|
||||
|
||||
|
||||
if _CAN_USE_FLASH_ATTN:
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
|
||||
try:
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
# Handle ABI mismatch or other import failures gracefully.
|
||||
# This can happen when flash_attn was compiled against a different PyTorch version.
|
||||
logger.warning(f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention.")
|
||||
_CAN_USE_FLASH_ATTN = False
|
||||
flash_attn_func = None
|
||||
flash_attn_varlen_func = None
|
||||
_wrapped_flash_attn_backward = None
|
||||
_wrapped_flash_attn_forward = None
|
||||
else:
|
||||
flash_attn_func = None
|
||||
flash_attn_varlen_func = None
|
||||
@@ -83,26 +95,47 @@ else:
|
||||
|
||||
|
||||
if _CAN_USE_FLASH_ATTN_3:
|
||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
||||
try:
|
||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_FLASH_ATTN_3 = False
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
else:
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
|
||||
if _CAN_USE_AITER_ATTN:
|
||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||
try:
|
||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_AITER_ATTN = False
|
||||
aiter_flash_attn_func = None
|
||||
else:
|
||||
aiter_flash_attn_func = None
|
||||
|
||||
if _CAN_USE_SAGE_ATTN:
|
||||
from sageattention import (
|
||||
sageattn,
|
||||
sageattn_qk_int8_pv_fp8_cuda,
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90,
|
||||
sageattn_qk_int8_pv_fp16_cuda,
|
||||
sageattn_qk_int8_pv_fp16_triton,
|
||||
sageattn_varlen,
|
||||
)
|
||||
try:
|
||||
from sageattention import (
|
||||
sageattn,
|
||||
sageattn_qk_int8_pv_fp8_cuda,
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90,
|
||||
sageattn_qk_int8_pv_fp16_cuda,
|
||||
sageattn_qk_int8_pv_fp16_triton,
|
||||
sageattn_varlen,
|
||||
)
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_SAGE_ATTN = False
|
||||
sageattn = None
|
||||
sageattn_qk_int8_pv_fp8_cuda = None
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
|
||||
sageattn_qk_int8_pv_fp16_cuda = None
|
||||
sageattn_qk_int8_pv_fp16_triton = None
|
||||
sageattn_varlen = None
|
||||
else:
|
||||
sageattn = None
|
||||
sageattn_qk_int8_pv_fp16_cuda = None
|
||||
@@ -113,26 +146,48 @@ else:
|
||||
|
||||
|
||||
if _CAN_USE_FLEX_ATTN:
|
||||
# We cannot import the flex_attention function from the package directly because it is expected (from the
|
||||
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
|
||||
# compiled function.
|
||||
import torch.nn.attention.flex_attention as flex_attention
|
||||
try:
|
||||
# We cannot import the flex_attention function from the package directly because it is expected (from the
|
||||
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
|
||||
# compiled function.
|
||||
import torch.nn.attention.flex_attention as flex_attention
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_FLEX_ATTN = False
|
||||
flex_attention = None
|
||||
else:
|
||||
flex_attention = None
|
||||
|
||||
|
||||
if _CAN_USE_NPU_ATTN:
|
||||
from torch_npu import npu_fusion_attention
|
||||
try:
|
||||
from torch_npu import npu_fusion_attention
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_NPU_ATTN = False
|
||||
npu_fusion_attention = None
|
||||
else:
|
||||
npu_fusion_attention = None
|
||||
|
||||
|
||||
if _CAN_USE_XLA_ATTN:
|
||||
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
|
||||
try:
|
||||
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_XLA_ATTN = False
|
||||
xla_flash_attention = None
|
||||
else:
|
||||
xla_flash_attention = None
|
||||
|
||||
|
||||
if _CAN_USE_XFORMERS_ATTN:
|
||||
import xformers.ops as xops
|
||||
try:
|
||||
import xformers.ops as xops
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_XFORMERS_ATTN = False
|
||||
xops = None
|
||||
else:
|
||||
xops = None
|
||||
|
||||
@@ -158,8 +213,6 @@ else:
|
||||
_register_fake = register_fake_no_op
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# TODO(aryan): Add support for the following:
|
||||
# - Sage Attention++
|
||||
# - block sparse, radial and other attention methods
|
||||
@@ -1812,9 +1865,12 @@ class TemplatedRingAttention(torch.autograd.Function):
|
||||
out = out.to(torch.float32)
|
||||
lse = lse.to(torch.float32)
|
||||
|
||||
# Refer to:
|
||||
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
|
||||
if is_torch_version("<", "2.9.0"):
|
||||
# lse must be 4-D to broadcast with out (B, S, H, D).
|
||||
# Some backends (e.g. cuDNN on torch>=2.9) already return a
|
||||
# trailing-1 dim; others (e.g. flash-hub / native-flash) always
|
||||
# return 3-D lse, so we add the dim here when needed.
|
||||
# See: https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
|
||||
if lse.ndim == 3:
|
||||
lse = lse.unsqueeze(-1)
|
||||
if prev_out is not None:
|
||||
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
|
||||
@@ -2101,10 +2157,11 @@ def _templated_unified_attention(
|
||||
scatter_idx,
|
||||
)
|
||||
if return_lse:
|
||||
# lse is of shape (B, S, H_LOCAL, 1)
|
||||
# Refer to:
|
||||
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
|
||||
if is_torch_version("<", "2.9.0"):
|
||||
# lse from TemplatedRingAttention is 3-D (B, S, H_LOCAL) after its
|
||||
# final squeeze(-1). SeqAllToAllDim requires a 4-D input, so we add
|
||||
# the trailing dim here and remove it after the collective.
|
||||
# See: https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
|
||||
if lse.ndim == 3:
|
||||
lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1)
|
||||
lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
|
||||
lse = lse.squeeze(-1)
|
||||
|
||||
@@ -13,7 +13,7 @@ from .compile import TorchCompileTesterMixin
|
||||
from .ip_adapter import IPAdapterTesterMixin
|
||||
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
|
||||
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
|
||||
from .parallelism import ContextParallelTesterMixin
|
||||
from .parallelism import ContextParallelAttentionBackendsTesterMixin, ContextParallelTesterMixin
|
||||
from .quantization import (
|
||||
BitsAndBytesCompileTesterMixin,
|
||||
BitsAndBytesConfigMixin,
|
||||
@@ -45,6 +45,7 @@ __all__ = [
|
||||
"BitsAndBytesTesterMixin",
|
||||
"CacheTesterMixin",
|
||||
"ContextParallelTesterMixin",
|
||||
"ContextParallelAttentionBackendsTesterMixin",
|
||||
"CPUOffloadTesterMixin",
|
||||
"FasterCacheConfigMixin",
|
||||
"FasterCacheTesterMixin",
|
||||
|
||||
@@ -23,10 +23,8 @@ import torch.multiprocessing as mp
|
||||
|
||||
from diffusers.models._modeling_parallel import ContextParallelConfig
|
||||
|
||||
from ...testing_utils import (
|
||||
is_context_parallel,
|
||||
require_torch_multi_accelerator,
|
||||
)
|
||||
from ...testing_utils import is_context_parallel, is_kernels_available, require_torch_multi_accelerator
|
||||
from .utils import _maybe_cast_to_bf16
|
||||
|
||||
|
||||
def _find_free_port():
|
||||
@@ -38,7 +36,9 @@ def _find_free_port():
|
||||
return port
|
||||
|
||||
|
||||
def _context_parallel_worker(rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict):
|
||||
def _context_parallel_worker(
|
||||
rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict, attention_backend=None
|
||||
):
|
||||
"""Worker function for context parallel testing."""
|
||||
try:
|
||||
# Set up distributed environment
|
||||
@@ -59,6 +59,9 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
# Cast as needed.
|
||||
model, inputs_dict = _maybe_cast_to_bf16(attention_backend, model, inputs_dict)
|
||||
|
||||
# Move inputs to device
|
||||
inputs_on_device = {}
|
||||
for key, value in inputs_dict.items():
|
||||
@@ -67,6 +70,13 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
else:
|
||||
inputs_on_device[key] = value
|
||||
|
||||
# Enable attention backend
|
||||
if attention_backend:
|
||||
try:
|
||||
model.set_attention_backend(attention_backend)
|
||||
except Exception as e:
|
||||
pytest.skip(f"Skipping test because of exception: {e}.")
|
||||
|
||||
# Enable context parallelism
|
||||
cp_config = ContextParallelConfig(**cp_dict)
|
||||
model.enable_parallelism(config=cp_config)
|
||||
@@ -126,3 +136,76 @@ class ContextParallelTesterMixin:
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
|
||||
@is_context_parallel
|
||||
@require_torch_multi_accelerator
|
||||
class ContextParallelAttentionBackendsTesterMixin:
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"])
|
||||
@pytest.mark.parametrize(
|
||||
"attention_backend",
|
||||
[
|
||||
"native",
|
||||
pytest.param(
|
||||
"flash_hub",
|
||||
marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
|
||||
),
|
||||
pytest.param(
|
||||
"_flash_3_hub",
|
||||
marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("ulysses_anything", [True, False])
|
||||
@torch.no_grad()
|
||||
def test_context_parallel_attn_backend_inference(self, cp_type, attention_backend, ulysses_anything):
|
||||
if not torch.distributed.is_available():
|
||||
pytest.skip("torch.distributed is not available.")
|
||||
|
||||
if getattr(self.model_class, "_cp_plan", None) is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
if cp_type == "ring_degree":
|
||||
if attention_backend == "native":
|
||||
pytest.skip("Skipping test because ulysses isn't supported with native attention backend.")
|
||||
|
||||
if ulysses_anything and "ulysses" not in cp_type:
|
||||
pytest.skip("Skipping test as ulysses anything needs the ulysses degree set.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
# Move all tensors to CPU for multiprocessing
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
cp_dict = {cp_type: world_size}
|
||||
if ulysses_anything:
|
||||
cp_dict.update({"ulysses_anything": ulysses_anything})
|
||||
|
||||
# Find a free port for distributed communication
|
||||
master_port = _find_free_port()
|
||||
|
||||
# Use multiprocessing manager for cross-process communication
|
||||
manager = mp.Manager()
|
||||
return_dict = manager.dict()
|
||||
|
||||
# Spawn worker processes
|
||||
mp.spawn(
|
||||
_context_parallel_worker,
|
||||
args=(
|
||||
world_size,
|
||||
master_port,
|
||||
self.model_class,
|
||||
init_dict,
|
||||
cp_dict,
|
||||
inputs_dict,
|
||||
return_dict,
|
||||
attention_backend,
|
||||
),
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
22
tests/models/testing_utils/utils.py
Normal file
22
tests/models/testing_utils/utils.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
|
||||
from diffusers.models.attention_dispatch import AttentionBackendName
|
||||
|
||||
|
||||
_BF16_REQUIRED_BACKENDS = {
|
||||
AttentionBackendName._NATIVE_CUDNN,
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
}
|
||||
|
||||
|
||||
def _maybe_cast_to_bf16(backend, model, inputs_dict):
|
||||
"""Cast model and floating-point inputs to bfloat16 when the backend requires it."""
|
||||
if not backend or backend not in _BF16_REQUIRED_BACKENDS:
|
||||
return model, inputs_dict
|
||||
model = model.to(dtype=torch.bfloat16)
|
||||
inputs_dict = {
|
||||
k: v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
|
||||
for k, v in inputs_dict.items()
|
||||
}
|
||||
return model, inputs_dict
|
||||
@@ -29,6 +29,7 @@ from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesCompileTesterMixin,
|
||||
BitsAndBytesTesterMixin,
|
||||
ContextParallelAttentionBackendsTesterMixin,
|
||||
ContextParallelTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
@@ -228,6 +229,12 @@ class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextPar
|
||||
"""Context Parallel inference tests for Flux Transformer"""
|
||||
|
||||
|
||||
class TestFluxTransformerContextParallelAttnBackends(
|
||||
FluxTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin
|
||||
):
|
||||
"""Context Parallel inference x attention backends tests for Flux Transformer"""
|
||||
|
||||
|
||||
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
|
||||
"""IP Adapter tests for Flux Transformer."""
|
||||
|
||||
|
||||
@@ -72,6 +72,7 @@ OPTIONAL_TESTERS = [
|
||||
# Other testers
|
||||
("SingleFileTesterMixin", "single_file"),
|
||||
("IPAdapterTesterMixin", "ip_adapter"),
|
||||
("ContextParallelAttentionBackendsTesterMixin", "cp_attn"),
|
||||
]
|
||||
|
||||
|
||||
@@ -229,7 +230,14 @@ def determine_testers(model_info: dict, include_optional: list[str], imports: se
|
||||
|
||||
for tester, flag in OPTIONAL_TESTERS:
|
||||
if flag in include_optional:
|
||||
if tester not in testers:
|
||||
if tester == "ContextParallelAttentionBackendsTesterMixin":
|
||||
if (
|
||||
"cp_attn" in include_optional
|
||||
and "_cp_plan" in model_info["attributes"]
|
||||
and model_info["attributes"]["_cp_plan"] is not None
|
||||
):
|
||||
testers.append(tester)
|
||||
elif tester not in testers:
|
||||
testers.append(tester)
|
||||
|
||||
return testers
|
||||
@@ -530,6 +538,7 @@ def main():
|
||||
"faster_cache",
|
||||
"single_file",
|
||||
"ip_adapter",
|
||||
"cp_attn",
|
||||
"all",
|
||||
],
|
||||
help="Optional testers to include",
|
||||
|
||||
Reference in New Issue
Block a user