mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-25 12:20:33 +08:00
Compare commits
6 Commits
sayakpaul-
...
cp-attn-ba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ad9ac8dba6 | ||
|
|
acfa871347 | ||
|
|
547f3df0a0 | ||
|
|
1d12bd215f | ||
|
|
e7317067ab | ||
|
|
3f36c6d4f4 |
@@ -1865,9 +1865,12 @@ class TemplatedRingAttention(torch.autograd.Function):
|
||||
out = out.to(torch.float32)
|
||||
lse = lse.to(torch.float32)
|
||||
|
||||
# Refer to:
|
||||
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
|
||||
if is_torch_version("<", "2.9.0"):
|
||||
# lse must be 4-D to broadcast with out (B, S, H, D).
|
||||
# Some backends (e.g. cuDNN on torch>=2.9) already return a
|
||||
# trailing-1 dim; others (e.g. flash-hub / native-flash) always
|
||||
# return 3-D lse, so we add the dim here when needed.
|
||||
# See: https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
|
||||
if lse.ndim == 3:
|
||||
lse = lse.unsqueeze(-1)
|
||||
if prev_out is not None:
|
||||
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
|
||||
@@ -2154,10 +2157,11 @@ def _templated_unified_attention(
|
||||
scatter_idx,
|
||||
)
|
||||
if return_lse:
|
||||
# lse is of shape (B, S, H_LOCAL, 1)
|
||||
# Refer to:
|
||||
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
|
||||
if is_torch_version("<", "2.9.0"):
|
||||
# lse from TemplatedRingAttention is 3-D (B, S, H_LOCAL) after its
|
||||
# final squeeze(-1). SeqAllToAllDim requires a 4-D input, so we add
|
||||
# the trailing dim here and remove it after the collective.
|
||||
# See: https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
|
||||
if lse.ndim == 3:
|
||||
lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1)
|
||||
lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
|
||||
lse = lse.squeeze(-1)
|
||||
|
||||
@@ -13,7 +13,7 @@ from .compile import TorchCompileTesterMixin
|
||||
from .ip_adapter import IPAdapterTesterMixin
|
||||
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
|
||||
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
|
||||
from .parallelism import ContextParallelTesterMixin
|
||||
from .parallelism import ContextParallelAttentionBackendsTesterMixin, ContextParallelTesterMixin
|
||||
from .quantization import (
|
||||
BitsAndBytesCompileTesterMixin,
|
||||
BitsAndBytesConfigMixin,
|
||||
@@ -45,6 +45,7 @@ __all__ = [
|
||||
"BitsAndBytesTesterMixin",
|
||||
"CacheTesterMixin",
|
||||
"ContextParallelTesterMixin",
|
||||
"ContextParallelAttentionBackendsTesterMixin",
|
||||
"CPUOffloadTesterMixin",
|
||||
"FasterCacheConfigMixin",
|
||||
"FasterCacheTesterMixin",
|
||||
|
||||
@@ -23,10 +23,8 @@ import torch.multiprocessing as mp
|
||||
|
||||
from diffusers.models._modeling_parallel import ContextParallelConfig
|
||||
|
||||
from ...testing_utils import (
|
||||
is_context_parallel,
|
||||
require_torch_multi_accelerator,
|
||||
)
|
||||
from ...testing_utils import is_context_parallel, is_kernels_available, require_torch_multi_accelerator
|
||||
from .utils import _maybe_cast_to_bf16
|
||||
|
||||
|
||||
def _find_free_port():
|
||||
@@ -38,7 +36,9 @@ def _find_free_port():
|
||||
return port
|
||||
|
||||
|
||||
def _context_parallel_worker(rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict):
|
||||
def _context_parallel_worker(
|
||||
rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict, attention_backend=None
|
||||
):
|
||||
"""Worker function for context parallel testing."""
|
||||
try:
|
||||
# Set up distributed environment
|
||||
@@ -59,6 +59,9 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
# Cast as needed.
|
||||
model, inputs_dict = _maybe_cast_to_bf16(attention_backend, model, inputs_dict)
|
||||
|
||||
# Move inputs to device
|
||||
inputs_on_device = {}
|
||||
for key, value in inputs_dict.items():
|
||||
@@ -67,6 +70,13 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
else:
|
||||
inputs_on_device[key] = value
|
||||
|
||||
# Enable attention backend
|
||||
if attention_backend:
|
||||
try:
|
||||
model.set_attention_backend(attention_backend)
|
||||
except Exception as e:
|
||||
pytest.skip(f"Skipping test because of exception: {e}.")
|
||||
|
||||
# Enable context parallelism
|
||||
cp_config = ContextParallelConfig(**cp_dict)
|
||||
model.enable_parallelism(config=cp_config)
|
||||
@@ -126,3 +136,76 @@ class ContextParallelTesterMixin:
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
|
||||
@is_context_parallel
|
||||
@require_torch_multi_accelerator
|
||||
class ContextParallelAttentionBackendsTesterMixin:
|
||||
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"])
|
||||
@pytest.mark.parametrize(
|
||||
"attention_backend",
|
||||
[
|
||||
"native",
|
||||
pytest.param(
|
||||
"flash_hub",
|
||||
marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
|
||||
),
|
||||
pytest.param(
|
||||
"_flash_3_hub",
|
||||
marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("ulysses_anything", [True, False])
|
||||
@torch.no_grad()
|
||||
def test_context_parallel_attn_backend_inference(self, cp_type, attention_backend, ulysses_anything):
|
||||
if not torch.distributed.is_available():
|
||||
pytest.skip("torch.distributed is not available.")
|
||||
|
||||
if getattr(self.model_class, "_cp_plan", None) is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
if cp_type == "ring_degree":
|
||||
if attention_backend == "native":
|
||||
pytest.skip("Skipping test because ulysses isn't supported with native attention backend.")
|
||||
|
||||
if ulysses_anything and "ulysses" not in cp_type:
|
||||
pytest.skip("Skipping test as ulysses anything needs the ulysses degree set.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
# Move all tensors to CPU for multiprocessing
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
cp_dict = {cp_type: world_size}
|
||||
if ulysses_anything:
|
||||
cp_dict.update({"ulysses_anything": ulysses_anything})
|
||||
|
||||
# Find a free port for distributed communication
|
||||
master_port = _find_free_port()
|
||||
|
||||
# Use multiprocessing manager for cross-process communication
|
||||
manager = mp.Manager()
|
||||
return_dict = manager.dict()
|
||||
|
||||
# Spawn worker processes
|
||||
mp.spawn(
|
||||
_context_parallel_worker,
|
||||
args=(
|
||||
world_size,
|
||||
master_port,
|
||||
self.model_class,
|
||||
init_dict,
|
||||
cp_dict,
|
||||
inputs_dict,
|
||||
return_dict,
|
||||
attention_backend,
|
||||
),
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
22
tests/models/testing_utils/utils.py
Normal file
22
tests/models/testing_utils/utils.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import torch
|
||||
|
||||
from diffusers.models.attention_dispatch import AttentionBackendName
|
||||
|
||||
|
||||
_BF16_REQUIRED_BACKENDS = {
|
||||
AttentionBackendName._NATIVE_CUDNN,
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
}
|
||||
|
||||
|
||||
def _maybe_cast_to_bf16(backend, model, inputs_dict):
|
||||
"""Cast model and floating-point inputs to bfloat16 when the backend requires it."""
|
||||
if not backend or backend not in _BF16_REQUIRED_BACKENDS:
|
||||
return model, inputs_dict
|
||||
model = model.to(dtype=torch.bfloat16)
|
||||
inputs_dict = {
|
||||
k: v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
|
||||
for k, v in inputs_dict.items()
|
||||
}
|
||||
return model, inputs_dict
|
||||
@@ -29,6 +29,7 @@ from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesCompileTesterMixin,
|
||||
BitsAndBytesTesterMixin,
|
||||
ContextParallelAttentionBackendsTesterMixin,
|
||||
ContextParallelTesterMixin,
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
@@ -228,6 +229,12 @@ class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextPar
|
||||
"""Context Parallel inference tests for Flux Transformer"""
|
||||
|
||||
|
||||
class TestFluxTransformerContextParallelAttnBackends(
|
||||
FluxTransformerTesterConfig, ContextParallelAttentionBackendsTesterMixin
|
||||
):
|
||||
"""Context Parallel inference x attention backends tests for Flux Transformer"""
|
||||
|
||||
|
||||
class TestFluxTransformerIPAdapter(FluxTransformerTesterConfig, IPAdapterTesterMixin):
|
||||
"""IP Adapter tests for Flux Transformer."""
|
||||
|
||||
|
||||
@@ -72,6 +72,7 @@ OPTIONAL_TESTERS = [
|
||||
# Other testers
|
||||
("SingleFileTesterMixin", "single_file"),
|
||||
("IPAdapterTesterMixin", "ip_adapter"),
|
||||
("ContextParallelAttentionBackendsTesterMixin", "cp_attn"),
|
||||
]
|
||||
|
||||
|
||||
@@ -229,7 +230,14 @@ def determine_testers(model_info: dict, include_optional: list[str], imports: se
|
||||
|
||||
for tester, flag in OPTIONAL_TESTERS:
|
||||
if flag in include_optional:
|
||||
if tester not in testers:
|
||||
if tester == "ContextParallelAttentionBackendsTesterMixin":
|
||||
if (
|
||||
"cp_attn" in include_optional
|
||||
and "_cp_plan" in model_info["attributes"]
|
||||
and model_info["attributes"]["_cp_plan"] is not None
|
||||
):
|
||||
testers.append(tester)
|
||||
elif tester not in testers:
|
||||
testers.append(tester)
|
||||
|
||||
return testers
|
||||
@@ -530,6 +538,7 @@ def main():
|
||||
"faster_cache",
|
||||
"single_file",
|
||||
"ip_adapter",
|
||||
"cp_attn",
|
||||
"all",
|
||||
],
|
||||
help="Optional testers to include",
|
||||
|
||||
Reference in New Issue
Block a user