mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-25 04:10:34 +08:00
Compare commits
2 Commits
attn-backe
...
sayakpaul-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
022ac4ddf6 | ||
|
|
1f6ac1c3d1 |
@@ -111,7 +111,7 @@ if __name__ == "__main__":
|
||||
Call `torchrun` to run the inference script and use the `--nproc_per_node` argument to set the number of GPUs to use.
|
||||
|
||||
```bash
|
||||
torchrun run_distributed.py --nproc_per_node=2
|
||||
torchrun --nproc_per_node=2 run_distributed.py
|
||||
```
|
||||
|
||||
## device_map
|
||||
|
||||
@@ -62,6 +62,8 @@ _REQUIRED_FLEX_VERSION = "2.5.0"
|
||||
_REQUIRED_XLA_VERSION = "2.2"
|
||||
_REQUIRED_XFORMERS_VERSION = "0.0.29"
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
|
||||
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
|
||||
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
|
||||
@@ -73,8 +75,18 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
|
||||
|
||||
|
||||
if _CAN_USE_FLASH_ATTN:
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
|
||||
try:
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
# Handle ABI mismatch or other import failures gracefully.
|
||||
# This can happen when flash_attn was compiled against a different PyTorch version.
|
||||
logger.warning(f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention.")
|
||||
_CAN_USE_FLASH_ATTN = False
|
||||
flash_attn_func = None
|
||||
flash_attn_varlen_func = None
|
||||
_wrapped_flash_attn_backward = None
|
||||
_wrapped_flash_attn_forward = None
|
||||
else:
|
||||
flash_attn_func = None
|
||||
flash_attn_varlen_func = None
|
||||
@@ -83,26 +95,47 @@ else:
|
||||
|
||||
|
||||
if _CAN_USE_FLASH_ATTN_3:
|
||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
||||
try:
|
||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_FLASH_ATTN_3 = False
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
else:
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
|
||||
if _CAN_USE_AITER_ATTN:
|
||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||
try:
|
||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_AITER_ATTN = False
|
||||
aiter_flash_attn_func = None
|
||||
else:
|
||||
aiter_flash_attn_func = None
|
||||
|
||||
if _CAN_USE_SAGE_ATTN:
|
||||
from sageattention import (
|
||||
sageattn,
|
||||
sageattn_qk_int8_pv_fp8_cuda,
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90,
|
||||
sageattn_qk_int8_pv_fp16_cuda,
|
||||
sageattn_qk_int8_pv_fp16_triton,
|
||||
sageattn_varlen,
|
||||
)
|
||||
try:
|
||||
from sageattention import (
|
||||
sageattn,
|
||||
sageattn_qk_int8_pv_fp8_cuda,
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90,
|
||||
sageattn_qk_int8_pv_fp16_cuda,
|
||||
sageattn_qk_int8_pv_fp16_triton,
|
||||
sageattn_varlen,
|
||||
)
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_SAGE_ATTN = False
|
||||
sageattn = None
|
||||
sageattn_qk_int8_pv_fp8_cuda = None
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
|
||||
sageattn_qk_int8_pv_fp16_cuda = None
|
||||
sageattn_qk_int8_pv_fp16_triton = None
|
||||
sageattn_varlen = None
|
||||
else:
|
||||
sageattn = None
|
||||
sageattn_qk_int8_pv_fp16_cuda = None
|
||||
@@ -113,26 +146,48 @@ else:
|
||||
|
||||
|
||||
if _CAN_USE_FLEX_ATTN:
|
||||
# We cannot import the flex_attention function from the package directly because it is expected (from the
|
||||
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
|
||||
# compiled function.
|
||||
import torch.nn.attention.flex_attention as flex_attention
|
||||
try:
|
||||
# We cannot import the flex_attention function from the package directly because it is expected (from the
|
||||
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
|
||||
# compiled function.
|
||||
import torch.nn.attention.flex_attention as flex_attention
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_FLEX_ATTN = False
|
||||
flex_attention = None
|
||||
else:
|
||||
flex_attention = None
|
||||
|
||||
|
||||
if _CAN_USE_NPU_ATTN:
|
||||
from torch_npu import npu_fusion_attention
|
||||
try:
|
||||
from torch_npu import npu_fusion_attention
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_NPU_ATTN = False
|
||||
npu_fusion_attention = None
|
||||
else:
|
||||
npu_fusion_attention = None
|
||||
|
||||
|
||||
if _CAN_USE_XLA_ATTN:
|
||||
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
|
||||
try:
|
||||
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_XLA_ATTN = False
|
||||
xla_flash_attention = None
|
||||
else:
|
||||
xla_flash_attention = None
|
||||
|
||||
|
||||
if _CAN_USE_XFORMERS_ATTN:
|
||||
import xformers.ops as xops
|
||||
try:
|
||||
import xformers.ops as xops
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_XFORMERS_ATTN = False
|
||||
xops = None
|
||||
else:
|
||||
xops = None
|
||||
|
||||
@@ -158,8 +213,6 @@ else:
|
||||
_register_fake = register_fake_no_op
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# TODO(aryan): Add support for the following:
|
||||
# - Sage Attention++
|
||||
# - block sparse, radial and other attention methods
|
||||
|
||||
Reference in New Issue
Block a user