Mirror of https://github.com/vllm-project/vllm.git
[Bugfix] [ROCm] [AITER]: Fix aiter block quant not compatible with torch compile dynamo (#28716)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
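The commit moves the direct AITER quant call behind a registered custom op with a matching fake (meta) implementation, so Dynamo can trace it with FakeTensors and the op can be captured under CUDA graphs. Below is a minimal, generic sketch of that pattern using the standard torch.library API; the op name demo::scale_rows and its body are illustrative only and are not part of this commit, which instead uses vLLM's direct_register_custom_op helper with the _rocm_aiter_group_fp8_quant_impl/_fake pair shown in the diff.

import torch

# Hypothetical toy op for illustration; the real op in this commit is
# vllm::rocm_aiter_group_fp8_quant, registered via direct_register_custom_op.
@torch.library.custom_op("demo::scale_rows", mutates_args=())
def scale_rows(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Eager implementation (in the real op this calls an opaque AITER kernel).
    return (x * factor).contiguous()


@scale_rows.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Fake (meta) implementation: only allocates the output shape/dtype.
    # Without it, torch.compile cannot trace calls to the op with FakeTensors.
    return torch.empty_like(x)


# The custom op now composes with torch.compile in fullgraph mode.
compiled = torch.compile(lambda t: scale_rows(t, 2.0), fullgraph=True)
print(compiled(torch.randn(4, 4)))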
tests/rocm/aiter/test_grouped_quant.py (new file, +137 lines)
@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# This is a test for the AITER group_fp8_quant op.
# It tests whether the AITER op:
# 1. correctly defines the relationship between the
#    implementation and the fake function
# 2. can be used with torch.compile
# 3. can be used with CUDA graphs
# This file will be skipped if AITER is not installed
# or the platform is not ROCm.

import importlib.util

import pytest
import torch

# this import statement is needed to ensure the ops are registered
from vllm._aiter_ops import rocm_aiter_ops
from vllm.platforms import current_platform

# Check if the aiter package is installed
aiter_available = importlib.util.find_spec("aiter") is not None

pytestmark = pytest.mark.skipif(
    not (current_platform.is_rocm() and aiter_available),
    reason="AITER ops are only available on ROCm with aiter package installed",
)


def test_rocm_aiter_group_fp8_quant_fake_implementation():
    """Test that the fake implementation is correctly
    defined for torch.ops.vllm.rocm_aiter_group_fp8_quant."""
    # Create test tensors
    M = 128
    N = 4096
    group_size = 128

    input_tensor = torch.randn((M, N), dtype=torch.bfloat16, device="cuda")

    # Verify the op's fake implementation using torch.library.opcheck.
    # This checks that the fake function returns tensors with correct shapes and dtypes.
    torch.library.opcheck(
        torch.ops.vllm.rocm_aiter_group_fp8_quant,
        (input_tensor, group_size),
        test_utils=("test_faketensor",),
    )


def test_rocm_aiter_group_fp8_quant_torch_compile_with_cudagraph():
    """Test that rocm_aiter_ops.group_fp8_quant
    with group size 128 can be used with
    torch.compile in cudagraph mode."""
    # Create test tensors
    M = 128
    N = 4096
    group_size = 128

    input_tensor = torch.randn((M, N), dtype=torch.bfloat16, device="cuda")

    # Define a function that uses the op
    def group_fp8_quant_fn(x):
        return rocm_aiter_ops.group_fp8_quant(x, group_size)

    # Compile with cudagraph mode
    compiled_fn = torch.compile(
        group_fp8_quant_fn,
        fullgraph=True,
        backend="inductor",
        mode="reduce-overhead",
        dynamic=False,
    )

    # Run eager mode
    x_fp8_eager, scales_eager = group_fp8_quant_fn(input_tensor)

    # Run compiled version (first run will trigger compilation)
    x_fp8_compiled, scales_compiled = compiled_fn(input_tensor)

    # Verify shapes match
    assert x_fp8_compiled.shape == x_fp8_eager.shape
    assert scales_compiled.shape == scales_eager.shape

    # Verify expected shapes
    assert x_fp8_compiled.shape == (M, N)
    expected_scale_cols = (N + group_size - 1) // group_size
    assert scales_compiled.shape == (M, expected_scale_cols)

    # Verify results match
    assert torch.allclose(
        x_fp8_compiled.to(torch.float32),
        x_fp8_eager.to(torch.float32),
        rtol=1e-2,
        atol=1e-2,
    )
    assert torch.allclose(scales_compiled, scales_eager, rtol=1e-3, atol=1e-3)

    # Test with a different input (reusing the compiled graph)
    input_tensor_2 = torch.randn((M, N), dtype=torch.bfloat16, device="cuda")
    x_fp8_eager_2, scales_eager_2 = group_fp8_quant_fn(input_tensor_2)
    x_fp8_compiled_2, scales_compiled_2 = compiled_fn(input_tensor_2)

    # Verify the second run also produces correct results
    assert torch.allclose(
        x_fp8_compiled_2.to(torch.float32),
        x_fp8_eager_2.to(torch.float32),
        rtol=1e-2,
        atol=1e-2,
    )
    assert torch.allclose(scales_compiled_2, scales_eager_2, rtol=1e-3, atol=1e-3)


def test_rocm_aiter_group_fp8_quant_different_shapes():
    """Test rocm_aiter_ops.group_fp8_quant with different input shapes."""
    group_size = 128

    test_shapes = [
        (64, 2048),
        (256, 8192),
        (32, 1024),
        (512, 4096),
    ]

    for M, N in test_shapes:
        input_tensor = torch.randn((M, N), dtype=torch.bfloat16, device="cuda")

        x_fp8, scales = rocm_aiter_ops.group_fp8_quant(input_tensor, group_size)

        # Verify shapes
        assert x_fp8.shape == (M, N)
        expected_scale_cols = (N + group_size - 1) // group_size
        assert scales.shape == (M, expected_scale_cols)

        # Verify dtypes
        from aiter import dtypes

        assert x_fp8.dtype == dtypes.fp8
        assert scales.dtype == torch.float32
vllm/_aiter_ops.py
@@ -43,6 +43,36 @@ def if_aiter_supported(func: Callable) -> Callable:
     return wrapper


+def _rocm_aiter_group_fp8_quant_impl(
+    x: torch.Tensor,
+    group_size: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    assert x.shape[-1] % group_size == 0, "Input shape must be divisible by group size"
+    from aiter import QuantType, dtypes, get_hip_quant
+
+    aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
+    return aiter_per1x128_quant(x.contiguous(), quant_dtype=dtypes.fp8)
+
+
+def _rocm_aiter_group_fp8_quant_fake(
+    x: torch.Tensor,
+    group_size: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    from aiter import dtypes
+
+    M, N = x.shape
+    x_fp8 = torch.empty((M, N), dtype=dtypes.fp8, device=x.device)
+    out_bs = torch.empty(
+        (
+            M,
+            (N + group_size - 1) // group_size,
+        ),
+        dtype=torch.float32,
+        device=x.device,
+    )
+    return x_fp8, out_bs
+
+
 def _rocm_aiter_fused_moe_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -512,6 +542,14 @@ class rocm_aiter_ops:
         )

         # register all the custom ops here
+        direct_register_custom_op(
+            op_name="rocm_aiter_group_fp8_quant",
+            op_func=_rocm_aiter_group_fp8_quant_impl,
+            mutates_args=[],
+            fake_impl=_rocm_aiter_group_fp8_quant_fake,
+            dispatch_key=current_platform.dispatch_key,
+        )
+
         direct_register_custom_op(
             op_name="rocm_aiter_asm_moe_tkw1",
             op_func=_rocm_aiter_asm_moe_tkw1_impl,
@@ -887,14 +925,12 @@ class rocm_aiter_ops:
         return gemm_a8w8_blockscale(A, B, As, Bs, dtype=output_dtype)

     @staticmethod
-    def per_1x128_fp8_quant(
+    def group_fp8_quant(
         input_2d: torch.Tensor,
+        group_size: int = 128,
     ) -> tuple[torch.Tensor, ...]:
-        """Only applies quantization method for fp8 data type only."""
-        from aiter import QuantType, dtypes, get_hip_quant
-
-        aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
-        return aiter_per1x128_quant(input_2d.contiguous(), quant_dtype=dtypes.fp8)
+        assert group_size == 128, "Group size must be 128"
+        return torch.ops.vllm.rocm_aiter_group_fp8_quant(input_2d, group_size)

     @staticmethod
     def is_triton_gemm_w8a8_tuned(n: int, k: int) -> bool:
@@ -342,7 +342,7 @@ class W8A8BlockFp8LinearOp:
             )
         # MI300 uses tuned AITER ASM/C++ kernel
         else:
-            q_input, input_scale = rocm_aiter_ops.per_1x128_fp8_quant(input_2d)
+            q_input, input_scale = rocm_aiter_ops.group_fp8_quant(input_2d)

             return gemm_a8w8_blockscale_op(
                 q_input,
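For reference, a minimal usage sketch of the renamed public wrapper, mirroring what the tests above exercise. It assumes a ROCm GPU with the aiter package installed (the same skip condition as the test file); the input shape is arbitrary but its last dimension must be divisible by the group size.

import torch
from vllm._aiter_ops import rocm_aiter_ops

x = torch.randn((16, 512), dtype=torch.bfloat16, device="cuda")
q, scales = rocm_aiter_ops.group_fp8_quant(x, group_size=128)

# Per the fake implementation and the tests: the fp8 output keeps the input
# shape, and one float32 scale is produced per 1x128 group along the last dim.
assert q.shape == (16, 512)
assert scales.shape == (16, 512 // 128)
assert scales.dtype == torch.float32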