[small][batch invariance] Rename the env and internal flags to simplify usage (#26855)
Signed-off-by: Bram Wasti <bwasti@meta.com>
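The rename is mechanical: the environment variable `VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT` becomes `VLLM_BATCH_INVARIANT`, and the helper `vllm_kernel_override_batch_invariant()` becomes `vllm_is_batch_invariant()` in both the C++ and Python layers. A minimal usage sketch under the new names (the import path follows the hunks below):

```python
import os

# Enable batch-invariant kernels before vLLM initializes
# (formerly VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1).
os.environ["VLLM_BATCH_INVARIANT"] = "1"

from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant

assert vllm_is_batch_invariant()
```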
@@ -5,11 +5,11 @@
 namespace vllm {
 
-// vllm_kernel_override_batch_invariant(); returns true
-// if env VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT=1
-inline bool vllm_kernel_override_batch_invariant() {
+// vllm_is_batch_invariant(); returns true
+// if env VLLM_BATCH_INVARIANT=1
+inline bool vllm_is_batch_invariant() {
   static bool cached = []() {
-    std::string env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT";
+    std::string env_key = "VLLM_BATCH_INVARIANT";
     const char* val = std::getenv(env_key.c_str());
     return (val && std::atoi(val) != 0) ? 1 : 0;
   }();
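The C++ helper reads the environment variable once and caches the result in a function-local static, so flipping the variable after process start has no effect. A rough Python analogue of that caching behavior (illustrative only, not vLLM code):

```python
import os
from functools import lru_cache


@lru_cache(maxsize=1)  # mirrors the C++ function-local static: getenv runs once
def is_batch_invariant_cached() -> bool:
    val = os.getenv("VLLM_BATCH_INVARIANT", "")
    try:
        return int(val) != 0
    except ValueError:
        return False  # roughly like atoi(): treat non-numeric values as 0
```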
@@ -426,7 +426,7 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
       wt_ptr % req_alignment_bytes == 0;
   bool offsets_are_multiple_of_vector_width =
       hidden_size % vector_width == 0 && input_stride % vector_width == 0;
-  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
   if (ptrs_are_aligned && offsets_are_multiple_of_vector_width &&
       !batch_invariant_launch) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
@@ -474,7 +474,7 @@ void poly_norm(torch::Tensor& out, // [..., hidden_size]
   auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
   auto out_ptr = reinterpret_cast<std::uintptr_t>(out.data_ptr());
   bool ptrs_are_aligned = inp_ptr % 16 == 0 && out_ptr % 16 == 0;
-  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
   if (ptrs_are_aligned && hidden_size % 8 == 0 && !batch_invariant_launch) {
     LAUNCH_FUSED_POLY_NORM(8);
   } else {
@@ -254,7 +254,7 @@ void fused_add_rms_norm_static_fp8_quant(
   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
   bool ptrs_are_aligned =
       inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
-  bool batch_invariant_launch = vllm::vllm_kernel_override_batch_invariant();
+  bool batch_invariant_launch = vllm::vllm_is_batch_invariant();
   if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0 &&
       !batch_invariant_launch) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
@@ -39,7 +39,7 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
-        # m.setenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1")
+        # m.setenv("VLLM_BATCH_INVARIANT", "1")
 
         outputs: list[tuple[str, list]] = []
         for test_preemption in [False, True]:
@@ -19,14 +19,14 @@ hopper_only = pytest.mark.skipif(
 @pytest.fixture(autouse=True)
 def enable_batch_invariant_mode():
     """Automatically enable batch invariant kernel overrides for all tests."""
-    old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
-    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"
+    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
+    os.environ["VLLM_BATCH_INVARIANT"] = "1"
     yield
     # Restore original value after test
     if old_value is None:
-        os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
+        os.environ.pop("VLLM_BATCH_INVARIANT", None)
     else:
-        os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
+        os.environ["VLLM_BATCH_INVARIANT"] = old_value
 
 
 def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
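The autouse fixture above saves and restores the variable by hand so tests cannot leak state into each other. pytest's `monkeypatch` fixture gives the same save/restore semantics with less bookkeeping; a hypothetical equivalent:

```python
import pytest


@pytest.fixture(autouse=True)
def enable_batch_invariant_mode(monkeypatch: pytest.MonkeyPatch):
    # monkeypatch undoes the change (or removal) automatically at teardown
    monkeypatch.setenv("VLLM_BATCH_INVARIANT", "1")
    yield
```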
@@ -231,10 +231,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
     # For batch invariance, disable custom all-reduce to ensure deterministic
     # all-reduce operations (custom all-reduce may not be deterministic)
     from vllm.model_executor.layers.batch_invariant import (
-        vllm_kernel_override_batch_invariant,
+        vllm_is_batch_invariant,
     )
 
-    disable_custom_ar = vllm_kernel_override_batch_invariant()
+    disable_custom_ar = vllm_is_batch_invariant()
 
     if disable_custom_ar:
         print(f"\n{'=' * 80}")
@@ -494,8 +494,8 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
     os.environ["VLLM_ATTENTION_BACKEND"] = backend
 
     # CRITICAL: Disable batch invariance for this test
-    old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT")
-    os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "0"
+    old_value = os.environ.get("VLLM_BATCH_INVARIANT")
+    os.environ["VLLM_BATCH_INVARIANT"] = "0"
 
     try:
         seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
@@ -687,9 +687,9 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
     finally:
         # Restore original value
         if old_value is None:
-            os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None)
+            os.environ.pop("VLLM_BATCH_INVARIANT", None)
         else:
-            os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value
+            os.environ["VLLM_BATCH_INVARIANT"] = old_value
 
 
 @hopper_only
@@ -718,10 +718,10 @@ def test_decode_logprobs_match_prefill_logprobs(backend):
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     from vllm.model_executor.layers.batch_invariant import (
-        vllm_kernel_override_batch_invariant,
+        vllm_is_batch_invariant,
     )
 
-    disable_custom_ar = vllm_kernel_override_batch_invariant()
+    disable_custom_ar = vllm_is_batch_invariant()
 
     if disable_custom_ar:
         print(f"\n{'=' * 80}")
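These tests assert bitwise-identical logprobs between a batch of one and larger mixed batches. A hedged sketch of the comparison idiom (hypothetical helper; the real tests extract per-token logprobs from vLLM outputs):

```python
import torch


def assert_bitwise_equal(logprobs_bs1: torch.Tensor, logprobs_bsN: torch.Tensor):
    # Exact elementwise equality, no tolerance: batch invariance means the
    # same token gets the same bits regardless of batch composition.
    assert torch.equal(logprobs_bs1, logprobs_bsN)
```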
@@ -21,7 +21,7 @@ from vllm.config.scheduler import RunnerType
 from vllm.config.utils import assert_hashable, config, getattr_iter
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
@@ -423,7 +423,7 @@ class ModelConfig:
         video_pruning_rate: float | None,
     ) -> None:
         # Enable batch invariance settings if requested
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.enforce_eager = True
 
         # Set the default seed to 0 in V1.
@@ -15,7 +15,7 @@ import vllm.envs as envs
 from vllm.config.utils import config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
 from vllm.utils import cuda_device_count_stateless, get_open_ports_list
@@ -565,7 +565,7 @@ class ParallelConfig:
         from vllm.executor.executor_base import ExecutorBase
 
         # Enable batch invariance settings if requested
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.disable_custom_all_reduce = True
 
         if (
@@ -20,7 +20,7 @@ import vllm.envs as envs
 from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.utils import cuda_device_count_stateless, update_environment_variables
 
@@ -74,7 +74,7 @@ def should_nccl_symm_mem_allreduce(world_size: int, input_tensor: torch.Tensor)
         is_symmetric_memory_enabled,
     )
 
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         return False
 
     if not is_symmetric_memory_enabled():
@@ -10,7 +10,7 @@ from vllm.distributed.device_communicators.all_reduce_utils import (
 )
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
 
@@ -103,7 +103,7 @@ class SymmMemCommunicator:
             return
         self.force_multimem = force_multimem
         self.disabled = False
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.disabled = True
 
     def should_use_symm_mem(self, inp: torch.Tensor):
@@ -741,8 +741,8 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
     return AttentionBlockSize(block_m=16, block_n=16)
 
 
-def vllm_kernel_override_batch_invariant():
-    env_key = "VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"
+def vllm_is_batch_invariant():
+    env_key = "VLLM_BATCH_INVARIANT"
     is_overridden = False
     val = os.getenv(env_key, "0")
     try:
@@ -797,7 +797,7 @@ def override_envs_for_invariance():
 
 def init_batch_invariance():
     # this will hit all the csrc overrides as well
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         override_envs_for_invariance()
         enable_batch_invariant_mode()
 
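Call sites across the rest of the diff all follow the same gating pattern: query the flag once, then pick deterministic parameters. A condensed, purely illustrative sketch of that pattern (the dict keys are not a real vLLM API; the values come from the hunks in this commit):

```python
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant


def deterministic_launch_params() -> dict:
    # Illustrative summary of the gating used by the call sites below.
    if vllm_is_batch_invariant():
        return {
            "num_splits": 1,                    # fixed reduction order
            "use_sorted_topk": True,            # deterministic expert selection
            "disable_custom_all_reduce": True,  # custom AR may not be deterministic
        }
    return {}
```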
@@ -16,7 +16,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG,
@@ -841,7 +841,7 @@ def get_moe_configs(
     """
 
     # Avoid optimizing for the batch invariant case. Use default config
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         return None
 
     # First look up if an optimized configuration is available in the configs
@@ -976,7 +976,7 @@ def get_default_config(
     dtype: str | None,
     block_shape: list[int] | None = None,
 ) -> dict[str, int]:
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": 64,
@@ -1136,7 +1136,7 @@ def fused_topk_bias(
     ) + e_score_correction_bias.unsqueeze(0)
 
     # For batch invariance, use sorted=True to ensure deterministic expert selection
-    use_sorted = vllm_kernel_override_batch_invariant()
+    use_sorted = vllm_is_batch_invariant()
     topk_indices = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=use_sorted)[1]
     topk_weights = scores.gather(1, topk_indices)
     if renormalize:
@@ -1200,7 +1200,7 @@ def grouped_topk(
     )  # [n, n_group]
 
     # For batch invariance, use sorted=True to ensure deterministic expert selection
-    use_sorted = vllm_kernel_override_batch_invariant()
+    use_sorted = vllm_is_batch_invariant()
     group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=use_sorted)[
         1
     ]  # [n, top_k_group]
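`torch.topk(..., sorted=True)` returns the top-k entries in descending order, while `sorted=False` leaves the order unspecified and free to differ across kernels and batch shapes, so the downstream `gather` sees a stable expert order only in the sorted case. A small demonstration:

```python
import torch

scores = torch.tensor([[0.1, 0.9, 0.5, 0.7]])
# sorted=True pins the index order (here: expert 1, then expert 3),
# making the gathered weights reproducible across runs and batch sizes.
topk_indices = torch.topk(scores, k=2, dim=-1, sorted=True)[1]
topk_weights = scores.gather(1, topk_indices)
```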
@@ -10,7 +10,7 @@ import vllm.envs as envs
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.batch_invariant import (
     rms_norm_batch_invariant,
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
@@ -25,7 +25,7 @@ def rms_norm(
 ) -> torch.Tensor:
     from vllm import _custom_ops as ops
 
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         return rms_norm_batch_invariant(x, weight, variance_epsilon)
     out = torch.empty_like(x)
     ops.rms_norm(
@@ -45,7 +45,7 @@ def fused_add_rms_norm(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     from vllm import _custom_ops as ops
 
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         return rms_norm_batch_invariant(
             x + residual, weight, variance_epsilon
         ), x + residual
@@ -15,7 +15,7 @@ from vllm import _custom_ops as ops
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.model_executor.layers.fused_moe import (
     FusedMoE,
@@ -356,7 +356,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # Disable marlin for rocm
         if current_platform.is_rocm():
             self.use_marlin = False
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.use_marlin = False
 
         self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()
@@ -540,7 +540,7 @@ class Fp8LinearMethod(LinearMethodBase):
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
         # If batch invariant mode is enabled, dequantize and use BF16 compute
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             # Dequantize FP8 weights to BF16
             weight_fp8 = layer.weight.to(torch.bfloat16)
             weight_scale = layer.weight_scale.to(torch.bfloat16)
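In batch-invariant mode the FP8 linear path falls back to dequantized BF16 compute, trading throughput for a deterministic GEMM. A hedged sketch of that fallback (per-tensor scale assumed; the actual vLLM method handles more cases):

```python
import torch
import torch.nn.functional as F


def bf16_fallback_linear(x, weight_fp8, weight_scale, bias=None):
    # Dequantize FP8 weights to BF16 and run a plain BF16 GEMM, avoiding
    # FP8 kernels whose split/accumulation order can vary with batch size.
    w = weight_fp8.to(torch.bfloat16) * weight_scale.to(torch.bfloat16)
    return F.linear(x.to(torch.bfloat16), w, bias)
```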
@@ -35,7 +35,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (
@@ -308,7 +308,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
         # we only set num_splits when using cuda graphs.
         max_num_splits = self.max_num_splits
 
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
            max_num_splits = 1
 
         def schedule(
@@ -484,7 +484,7 @@ class FlashAttentionImpl(AttentionImpl):
         self.attn_type = attn_type
         self.vllm_flash_attn_version = get_flash_attn_version()
         # Cache the batch invariant result for use in forward passes
-        self.batch_invariant_enabled = vllm_kernel_override_batch_invariant()
+        self.batch_invariant_enabled = vllm_is_batch_invariant()
 
         if is_quantized_kv_cache(self.kv_cache_dtype) and not flash_attn_supports_fp8():
             raise NotImplementedError(
@@ -963,7 +963,7 @@ def cascade_attention(
         # s_aux is incorporated into prefix_lse inside the GPU kernel,
         # enabling its effect during the final attention merge.
         s_aux=s_aux,
-        num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
+        num_splits=1 if vllm_is_batch_invariant() else 0,
     )
 
     descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
@@ -988,7 +988,7 @@ def cascade_attention(
         q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
         k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
         v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
-        num_splits=1 if vllm_kernel_override_batch_invariant() else 0,
+        num_splits=1 if vllm_is_batch_invariant() else 0,
     )
 
     # Merge prefix and suffix outputs, and store the result in output.
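`num_splits=1` forces the attention reduction into a single split. Split-k kernels normally pick the split count from batch size and occupancy, and because floating-point addition is not associative, a different split count changes the summation order and hence the low-order bits. The underlying effect in two lines:

```python
a, b, c = 1e16, -1e16, 1.0
assert (a + b) + c != a + (b + c)  # 1.0 vs 0.0: grouping changes the result
```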
@@ -25,7 +25,7 @@ from vllm.attention.backends.abstract import (
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     QuantKey,
@@ -291,7 +291,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self._prefill_wrapper = None  # Wrapper for prefill/append
         self._decode_wrapper = None  # Wrapper for decode (general shape)
 
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.decode_fixed_split_size = 2048
             self.prefill_fixed_split_size = 4096
             self.disable_split_kv = True
@@ -404,7 +404,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
             buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
-            if vllm_kernel_override_batch_invariant():
+            if vllm_is_batch_invariant():
                 buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE_BATCH_INVARIANT
             self._workspace_buffer = torch.zeros(
                 buffer_size, dtype=torch.uint8, device=self.device
@@ -26,7 +26,7 @@ from vllm.attention.backends.abstract import (
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.utils import cdiv, is_torch_equal_or_newer
 from vllm.v1.attention.backends.utils import (
@@ -863,7 +863,7 @@ def get_kernel_options(
     kernel_options: dict[str, int | bool] = {
         "FORCE_USE_FLEX_ATTENTION": True,
     }
-    if vllm_kernel_override_batch_invariant():
+    if vllm_is_batch_invariant():
         kernel_options["BLOCK_M"] = 16
         kernel_options["BLOCK_N"] = 16
         kernel_options["IS_DIVISIBLE"] = False
@@ -212,7 +212,7 @@ from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
@@ -1283,7 +1283,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             # ROCm leverages the upstream flash_attn, which takes a parameter
            # called "return_attn_probs" instead of return_softmax_lse
             kwargs["return_attn_probs"] = return_softmax_lse
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             kwargs["num_splits"] = 1
 
         attn_out = self.flash_attn_varlen_func(
@@ -19,7 +19,7 @@ from vllm.attention.utils.fa_utils import (
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
@@ -110,7 +110,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         # pre-allocated during capture.
         self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
 
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             self.max_num_splits = 1
 
     def _schedule_decode(
@@ -181,7 +181,7 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
         # we only set num_splits when using cuda graphs.
         max_num_splits = self.max_num_splits
 
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             max_num_splits = 1
 
         metadata = FlashAttnMLADecodeMetadata(
@@ -15,7 +15,7 @@ from vllm.attention.ops.flashmla import (
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
@@ -234,7 +234,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
 
         tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
         num_splits = attn_metadata.decode.num_splits
-        if vllm_kernel_override_batch_invariant():
+        if vllm_is_batch_invariant():
             device = q.device
             dtype = torch.int32
 
@@ -14,7 +14,7 @@ from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
 from vllm.attention.ops.triton_flash_attention import triton_attention
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
-    vllm_kernel_override_batch_invariant,
+    vllm_is_batch_invariant,
 )
 from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON
@@ -163,7 +163,7 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
         lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
 
         # For batch invariance, use only 1 split to ensure deterministic reduction
-        num_kv_splits = 1 if vllm_kernel_override_batch_invariant() else 4
+        num_kv_splits = 1 if vllm_is_batch_invariant() else 4
 
         # TODO(lucas) Allocate ahead of time
         attn_logits = torch.empty(
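End to end, setting the renamed flag makes every backend touched by this diff choose its deterministic path. A hedged usage sketch (model name illustrative):

```python
import os

os.environ["VLLM_BATCH_INVARIANT"] = "1"

from vllm import LLM, SamplingParams

# With the flag set, deterministic kernels/configs are selected, so logprobs
# for a given prompt match bitwise regardless of how the batch is composed.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model
out = llm.generate(["Hello"], SamplingParams(temperature=0.0, logprobs=1))
```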