mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-23 19:30:38 +08:00
Compare commits
2 Commits
remove-non
...
attn-backe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7220687151 | ||
|
|
73b159f2a1 |
@@ -1,4 +1,4 @@
|
||||
from .attention import AttentionTesterMixin
|
||||
from .attention import AttentionBackendTesterMixin, AttentionTesterMixin
|
||||
from .cache import (
|
||||
CacheTesterMixin,
|
||||
FasterCacheConfigMixin,
|
||||
@@ -38,6 +38,7 @@ from .training import TrainingTesterMixin
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AttentionBackendTesterMixin",
|
||||
"AttentionTesterMixin",
|
||||
"BaseModelTesterConfig",
|
||||
"BitsAndBytesCompileTesterMixin",
|
||||
|
||||
@@ -14,22 +14,105 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.models.attention import AttentionModuleMixin
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor,
|
||||
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry, attention_backend
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.utils import is_kernels_available, is_torch_version
|
||||
|
||||
from ...testing_utils import assert_tensors_close, backend_empty_cache, is_attention, torch_device
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Module-level backend parameter sets for AttentionBackendTesterMixin
# ---------------------------------------------------------------------------

# Environment probes evaluated once at import time; the skipif marks below
# capture their values at collection time.
_CUDA_AVAILABLE = torch.cuda.is_available()
_KERNELS_AVAILABLE = is_kernels_available()

_PARAM_NATIVE = pytest.param(AttentionBackendName.NATIVE, id="native")

_PARAM_NATIVE_CUDNN = pytest.param(
    AttentionBackendName._NATIVE_CUDNN,
    id="native_cudnn",
    marks=pytest.mark.skipif(
        not _CUDA_AVAILABLE,
        reason="CUDA is required for _native_cudnn backend.",
    ),
)

_PARAM_FLASH_HUB = pytest.param(
    AttentionBackendName.FLASH_HUB,
    id="flash_hub",
    marks=[
        pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for flash_hub backend."),
        pytest.mark.skipif(
            not _KERNELS_AVAILABLE,
            reason="`kernels` package is required for flash_hub backend. Install with `pip install kernels`.",
        ),
    ],
)

_PARAM_FLASH_3_HUB = pytest.param(
    AttentionBackendName._FLASH_3_HUB,
    id="flash_3_hub",
    marks=[
        pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for _flash_3_hub backend."),
        pytest.mark.skipif(
            not _KERNELS_AVAILABLE,
            reason="`kernels` package is required for _flash_3_hub backend. Install with `pip install kernels`.",
        ),
    ],
)

# All backends under test.
_ALL_BACKEND_PARAMS = [_PARAM_NATIVE, _PARAM_NATIVE_CUDNN, _PARAM_FLASH_HUB, _PARAM_FLASH_3_HUB]

# Backends that only accept bf16/fp16 inputs; models and inputs must be cast before running them.
_BF16_REQUIRED_BACKENDS = {
    AttentionBackendName._NATIVE_CUDNN,
    AttentionBackendName.FLASH_HUB,
    AttentionBackendName._FLASH_3_HUB,
}

# Backends that perform non-deterministic operations and therefore cannot run when
# torch.use_deterministic_algorithms(True) is active (e.g. after enable_full_determinism()).
_NON_DETERMINISTIC_BACKENDS = {AttentionBackendName._NATIVE_CUDNN}
|
||||
|
||||
|
||||
def _maybe_cast_to_bf16(backend, model, inputs_dict):
    """Return `(model, inputs_dict)`, cast to bfloat16 iff `backend` only accepts bf16/fp16.

    Only floating-point tensors in `inputs_dict` are converted; integer tensors
    (e.g. token ids) and non-tensor values pass through untouched.
    """
    if backend not in _BF16_REQUIRED_BACKENDS:
        return model, inputs_dict

    casted_inputs = {}
    for key, value in inputs_dict.items():
        if isinstance(value, torch.Tensor) and value.is_floating_point():
            value = value.to(dtype=torch.bfloat16)
        casted_inputs[key] = value
    return model.to(dtype=torch.bfloat16), casted_inputs
|
||||
|
||||
|
||||
def _skip_if_backend_requires_nondeterminism(backend):
    """Skip at runtime when torch.use_deterministic_algorithms(True) blocks the backend.

    The check is deliberately performed at test execution time rather than via a
    collection-time skipif: enable_full_determinism() is usually invoked at module
    level in test files *after* the module-level pytest.param() objects here were
    already evaluated, so a skipif condition would never observe it.
    """
    if backend not in _NON_DETERMINISTIC_BACKENDS:
        return
    if not torch.are_deterministic_algorithms_enabled():
        return
    pytest.skip(
        f"Backend '{backend.value}' performs non-deterministic operations and cannot run "
        f"while `torch.use_deterministic_algorithms(True)` is active."
    )
|
||||
|
||||
|
||||
@is_attention
|
||||
class AttentionTesterMixin:
|
||||
@@ -39,7 +122,6 @@ class AttentionTesterMixin:
|
||||
Tests functionality from AttentionModuleMixin including:
|
||||
- Attention processor management (set/get)
|
||||
- QKV projection fusion/unfusion
|
||||
- Attention backends (XFormers, NPU, etc.)
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
@@ -179,3 +261,208 @@ class AttentionTesterMixin:
|
||||
model.set_attn_processor(wrong_processors)
|
||||
|
||||
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
|
||||
|
||||
|
||||
@is_attention
class AttentionBackendTesterMixin:
    """
    Mixin that exercises the attention backends in `_ALL_BACKEND_PARAMS` against the
    model supplied by the host test class. Covered behaviors:

    1. A backend can be activated both via the `attention_backend()` context manager
       and via `set_attention_backend()`, and the two paths agree exactly.
    2. Backend outputs don't deviate too much from the native SDPA reference.
    3. Each backend works with (regional) compilation.
    4. The previously active backend can be restored.

    Expected from the host test class:
    - model_class: The model class to instantiate.
    - get_init_dict(): Returns dict of kwargs to construct the model.
    - get_dummy_inputs(): Returns dict of inputs for the model's forward pass.

    Pytest mark: attention
    Use `pytest -m "not attention"` to skip these tests.
    """

    # -----------------------------------------------------------------------
    # Tolerance attributes — override in host class to loosen/tighten checks.
    # Alternate backends (flash, cuDNN) accumulate small numerical error vs the
    # reference PyTorch SDPA kernel; regional compilation introduces the same
    # kind of error, hence identical defaults for both comparisons.
    # -----------------------------------------------------------------------
    backend_vs_native_atol: float = 1e-2
    backend_vs_native_rtol: float = 1e-2
    compile_vs_native_atol: float = 1e-2
    compile_vs_native_rtol: float = 1e-2

    def setup_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    # -----------------------------------------------------------------------
    # Internal helpers
    # -----------------------------------------------------------------------

    def _prepare_model_and_inputs(self, backend):
        # Build the host model on `torch_device` in eval mode; cast model and
        # inputs to bf16 when the backend only accepts bf16/fp16.
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()
        return _maybe_cast_to_bf16(backend, model, inputs_dict)

    def _activate_backend_or_skip(self, model, backend):
        # set_attention_backend raises when the backend is unusable in the current
        # environment (missing kernels, unsupported device, ...): skip, don't fail.
        try:
            model.set_attention_backend(backend.value)
        except Exception as e:
            logger.warning("Skipping test for backend '%s': %s", backend.value, e)
            pytest.skip(str(e))

    # -----------------------------------------------------------------------
    # Tests
    # -----------------------------------------------------------------------

    @torch.no_grad()
    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_set_attention_backend_matches_context_manager(self, backend):
        """set_attention_backend() and the attention_backend() context manager must yield identical outputs."""
        _skip_if_backend_requires_nondeterminism(backend)

        model, inputs_dict = self._prepare_model_and_inputs(backend)

        with attention_backend(backend):
            ctx_output = model(**inputs_dict, return_dict=False)[0]

        # Snapshot the registry state so it can be restored even if the forward fails.
        initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()

        self._activate_backend_or_skip(model, backend)
        try:
            set_output = model(**inputs_dict, return_dict=False)[0]
        finally:
            model.reset_attention_backend()
            _AttentionBackendRegistry.set_active_backend(initial_registry_backend)

        # Both paths dispatch to the same kernel, so we demand bit-exact equality.
        assert_tensors_close(
            set_output,
            ctx_output,
            atol=0,
            rtol=0,
            msg=(
                f"Output from model.set_attention_backend('{backend.value}') should be identical "
                f"to the output from `with attention_backend('{backend.value}'):`."
            ),
        )

    @torch.no_grad()
    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_output_close_to_native(self, backend):
        """All backends should produce model output numerically close to the native SDPA reference."""
        _skip_if_backend_requires_nondeterminism(backend)

        model, inputs_dict = self._prepare_model_and_inputs(backend)

        with attention_backend(AttentionBackendName.NATIVE):
            native_output = model(**inputs_dict, return_dict=False)[0]

        initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()

        self._activate_backend_or_skip(model, backend)
        try:
            backend_output = model(**inputs_dict, return_dict=False)[0]
        finally:
            model.reset_attention_backend()
            _AttentionBackendRegistry.set_active_backend(initial_registry_backend)

        assert_tensors_close(
            backend_output,
            native_output,
            atol=self.backend_vs_native_atol,
            rtol=self.backend_vs_native_rtol,
            msg=f"Output from {backend} should be numerically close to native SDPA.",
        )

    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_context_manager_switches_and_restores_backend(self, backend):
        """attention_backend() should activate the requested backend and restore the previous one on exit."""
        previous_backend, _ = _AttentionBackendRegistry.get_active_backend()

        with attention_backend(backend):
            current_backend, _ = _AttentionBackendRegistry.get_active_backend()
            assert current_backend == backend, (
                f"Backend should be {backend} inside the context manager, got {current_backend}."
            )

        after_backend, _ = _AttentionBackendRegistry.get_active_backend()
        assert after_backend == previous_backend, (
            f"Backend should be restored to {previous_backend} after exiting the context manager, "
            f"got {after_backend}."
        )

    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_compile(self, backend):
        """
        `torch.compile` tests checking for recompilation, graph breaks, forward can run, etc.
        For speed, we use regional compilation here (`model.compile_repeated_blocks()`
        as opposed to `model.compile`).
        """
        _skip_if_backend_requires_nondeterminism(backend)
        if getattr(self.model_class, "_repeated_blocks", None) is None:
            pytest.skip("Skipping tests as regional compilation is not supported.")

        if backend == AttentionBackendName.NATIVE and not is_torch_version(">=", "2.9.0"):
            pytest.xfail(
                "test_compile with the native backend requires torch >= 2.9.0 for stable "
                "fullgraph compilation with error_on_recompile=True."
            )

        model, inputs_dict = self._prepare_model_and_inputs(backend)

        # Eager native reference, computed before any backend switching/compilation.
        with torch.no_grad(), attention_backend(AttentionBackendName.NATIVE):
            native_output = model(**inputs_dict, return_dict=False)[0]

        initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()

        self._activate_backend_or_skip(model, backend)
        try:
            model.compile_repeated_blocks(fullgraph=True)
            torch.compiler.reset()

            with (
                torch._inductor.utils.fresh_inductor_cache(),
                torch._dynamo.config.patch(error_on_recompile=True),
            ):
                with torch.no_grad():
                    compile_output = model(**inputs_dict, return_dict=False)[0]
                    # A second forward with identical inputs must not trigger a
                    # recompile; error_on_recompile=True turns one into a failure.
                    model(**inputs_dict, return_dict=False)
        finally:
            model.reset_attention_backend()
            _AttentionBackendRegistry.set_active_backend(initial_registry_backend)

        assert_tensors_close(
            compile_output,
            native_output,
            atol=self.compile_vs_native_atol,
            rtol=self.compile_vs_native_rtol,
            msg=f"Compiled output with backend '{backend.value}' should be numerically close to eager native SDPA.",
        )
|
||||
|
||||
@@ -25,6 +25,7 @@ from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionBackendTesterMixin,
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesCompileTesterMixin,
|
||||
@@ -224,6 +225,10 @@ class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterM
|
||||
"""Attention processor tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerAttentionBackend(FluxTransformerTesterConfig, AttentionBackendTesterMixin):
|
||||
"""Attention backend tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
|
||||
"""Context Parallel inference tests for Flux Transformer"""
|
||||
|
||||
|
||||
@@ -1,163 +0,0 @@
|
||||
"""
|
||||
This test suite exists for the maintainers currently. It's not run in our CI at the moment.
|
||||
|
||||
Once attention backends become more mature, we can consider including this in our CI.
|
||||
|
||||
To run this test suite:
|
||||
|
||||
```bash
|
||||
export RUN_ATTENTION_BACKEND_TESTS=yes
|
||||
|
||||
pytest tests/others/test_attention_backends.py
|
||||
```
|
||||
|
||||
Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
|
||||
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
|
||||
|
||||
Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
|
||||
with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
|
||||
aiter 0.1.5.post4.dev20+ga25e55e79.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
|
||||
)
|
||||
from diffusers import FluxPipeline # noqa: E402
|
||||
from diffusers.utils import is_torch_version # noqa: E402
|
||||
|
||||
|
||||
# fmt: off
|
||||
FORWARD_CASES = [
|
||||
(
|
||||
"flash_hub",
|
||||
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
|
||||
),
|
||||
(
|
||||
"_flash_3_hub",
|
||||
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
|
||||
),
|
||||
(
|
||||
"native",
|
||||
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
|
||||
),
|
||||
(
|
||||
"_native_cudnn",
|
||||
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
|
||||
),
|
||||
(
|
||||
"aiter",
|
||||
torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
|
||||
)
|
||||
]
|
||||
|
||||
COMPILE_CASES = [
|
||||
(
|
||||
"flash_hub",
|
||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
|
||||
True
|
||||
),
|
||||
(
|
||||
"_flash_3_hub",
|
||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
|
||||
True,
|
||||
),
|
||||
(
|
||||
"native",
|
||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
|
||||
True,
|
||||
),
|
||||
(
|
||||
"_native_cudnn",
|
||||
torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
|
||||
True,
|
||||
),
|
||||
(
|
||||
"aiter",
|
||||
torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
|
||||
True,
|
||||
)
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
# Shared generation kwargs: tiny resolution, 2 denoising steps, and a short
# sequence length keep the golden-slice forward passes fast while still
# exercising the full pipeline. "pt" output keeps results as tensors for
# direct slice comparison.
INFER_KW = {
    "prompt": "dance doggo dance",
    "height": 256,
    "width": 256,
    "num_inference_steps": 2,
    "guidance_scale": 3.5,
    "max_sequence_length": 128,
    "output_type": "pt",
}
|
||||
|
||||
|
||||
def _backend_is_probably_supported(pipe, name: str):
|
||||
try:
|
||||
pipe.transformer.set_attention_backend(name)
|
||||
return pipe, True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _check_if_slices_match(output, expected_slice):
|
||||
img = output.images.detach().cpu()
|
||||
generated_slice = img.flatten()
|
||||
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||
assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def device():
    """Session-scoped CUDA device; skips dependent tests when CUDA is absent."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    pytest.skip("CUDA is required for these tests.")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def pipe(device):
    """Session-scoped FLUX.1-dev pipeline in bf16 on the CUDA device, progress bars disabled."""
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to(device)
    pipeline.set_progress_bar_config(disable=True)
    return pipeline
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
def test_forward(pipe, backend_name, expected_slice):
    """Run a tiny seeded generation with `backend_name` active and compare golden slices."""
    result = _backend_is_probably_supported(pipe, backend_name)
    if isinstance(result, bool):
        # `False` means activation raised; treat as environment limitation, not a bug.
        pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")

    active_pipe = result[0]
    output = active_pipe(**INFER_KW, generator=torch.manual_seed(0))
    _check_if_slices_match(output, expected_slice)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "backend_name,expected_slice,error_on_recompile",
    COMPILE_CASES,
    ids=[c[0] for c in COMPILE_CASES],
)
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
    """Fullgraph-compile the transformer, then run a seeded generation and compare golden slices."""
    # Native-family backends only compile cleanly under error_on_recompile on newer torch.
    if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
        pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")

    result = _backend_is_probably_supported(pipe, backend_name)
    if isinstance(result, bool):
        pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")

    active_pipe = result[0]
    active_pipe.transformer.compile(fullgraph=True)

    torch.compiler.reset()
    with (
        torch._inductor.utils.fresh_inductor_cache(),
        torch._dynamo.config.patch(error_on_recompile=error_on_recompile),
    ):
        output = active_pipe(**INFER_KW, generator=torch.manual_seed(0))

    _check_if_slices_match(output, expected_slice)
|
||||
Reference in New Issue
Block a user