mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-20 15:38:11 +08:00
Compare commits
1 Commits
fa4
...
export-saf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bc8817604a |
@@ -143,7 +143,6 @@ Refer to the table below for a complete list of available attention backends and
|
||||
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
|
||||
| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
|
||||
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
|
||||
| `flash_4_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-4 |
|
||||
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
|
||||
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Type
|
||||
@@ -32,7 +31,7 @@ from ..models._modeling_parallel import (
|
||||
gather_size_by_comm,
|
||||
)
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
|
||||
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
@@ -327,7 +326,7 @@ class PartitionAnythingSharder:
|
||||
return tensor
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=64)
|
||||
@lru_cache_unless_export(maxsize=64)
|
||||
def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
|
||||
gather_shapes = []
|
||||
for i in range(world_size):
|
||||
|
||||
@@ -49,7 +49,7 @@ from ..utils import (
|
||||
is_xformers_version,
|
||||
)
|
||||
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
|
||||
from ._modeling_parallel import gather_size_by_comm
|
||||
|
||||
|
||||
@@ -229,7 +229,6 @@ class AttentionBackendName(str, Enum):
|
||||
FLASH_HUB = "flash_hub"
|
||||
FLASH_VARLEN = "flash_varlen"
|
||||
FLASH_VARLEN_HUB = "flash_varlen_hub"
|
||||
FLASH_4_HUB = "flash_4_hub"
|
||||
_FLASH_3 = "_flash_3"
|
||||
_FLASH_VARLEN_3 = "_flash_varlen_3"
|
||||
_FLASH_3_HUB = "_flash_3_hub"
|
||||
@@ -359,11 +358,6 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
function_attr="sageattn",
|
||||
version=1,
|
||||
),
|
||||
AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-staging/flash-attn4",
|
||||
function_attr="flash_attn_func",
|
||||
version=0,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -527,7 +521,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
AttentionBackendName._FLASH_3_VARLEN_HUB,
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
AttentionBackendName.FLASH_4_HUB,
|
||||
]:
|
||||
if not is_kernels_available():
|
||||
raise RuntimeError(
|
||||
@@ -582,7 +575,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
@lru_cache_unless_export(maxsize=128)
|
||||
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
|
||||
batch_size: int,
|
||||
seq_len_q: int,
|
||||
@@ -2683,37 +2676,6 @@ def _flash_attention_3_varlen_hub(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_4_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_4_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
scale: float | None = None,
|
||||
is_causal: bool = False,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 4.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
)
|
||||
if isinstance(out, tuple):
|
||||
return (out[0], out[1]) if return_lse else out[0]
|
||||
return out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_VARLEN_3,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import math
|
||||
from math import prod
|
||||
from typing import Any
|
||||
@@ -25,7 +24,7 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, deprecate, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -307,7 +306,7 @@ class QwenEmbedRope(nn.Module):
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
@lru_cache_unless_export(maxsize=128)
|
||||
def _compute_video_freqs(
|
||||
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
|
||||
) -> torch.Tensor:
|
||||
@@ -428,7 +427,7 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@lru_cache_unless_export(maxsize=None)
|
||||
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
|
||||
seq_lens = frame * height * width
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
@@ -450,7 +449,7 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@lru_cache_unless_export(maxsize=None)
|
||||
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
|
||||
seq_lens = frame * height * width
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
|
||||
@@ -19,11 +19,16 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Callable, ParamSpec, TypeVar
|
||||
|
||||
from . import logging
|
||||
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.fft import fftn, fftshift, ifftn, ifftshift
|
||||
@@ -333,5 +338,21 @@ def disable_full_determinism():
|
||||
torch.use_deterministic_algorithms(False)
|
||||
|
||||
|
||||
@functools.wraps(functools.lru_cache)
|
||||
def lru_cache_unless_export(maxsize=128, typed=False):
|
||||
def outer_wrapper(fn: Callable[P, T]):
|
||||
cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
|
||||
|
||||
@functools.wraps(fn)
|
||||
def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
if torch.compiler.is_exporting():
|
||||
return fn(*args, **kwargs)
|
||||
return cached(*args, **kwargs)
|
||||
|
||||
return inner_wrapper
|
||||
|
||||
return outer_wrapper
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
torch_device = get_device()
|
||||
|
||||
Reference in New Issue
Block a user