mirror of
https://github.com/vllm-project/vllm.git
synced 2025-12-06 15:04:47 +08:00
Do not guard during noop elimination pass (#30095)
Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
@@ -5,6 +5,7 @@ from collections.abc import Iterable
|
|||||||
|
|
||||||
import torch.fx
|
import torch.fx
|
||||||
from torch import SymInt
|
from torch import SymInt
|
||||||
|
from torch.fx.experimental.symbolic_shapes import statically_known_true
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
@@ -116,12 +117,7 @@ class NoOpEliminationPass(VllmInductorPass):
|
|||||||
2. The dimensions both correspond to the same SymInt
|
2. The dimensions both correspond to the same SymInt
|
||||||
"""
|
"""
|
||||||
# Case 1
|
# Case 1
|
||||||
if isinstance(i_dim, int) and isinstance(dim, int):
|
return statically_known_true(dim == i_dim)
|
||||||
return dim == i_dim
|
|
||||||
# Case 2
|
|
||||||
if isinstance(i_dim, SymInt) and isinstance(dim, SymInt):
|
|
||||||
return dim == i_dim
|
|
||||||
return False
|
|
||||||
|
|
||||||
def all_dims_equivalent(
|
def all_dims_equivalent(
|
||||||
self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]
|
self, dims: Iterable[int | SymInt], i_dims: Iterable[int | SymInt]
|
||||||
|
|||||||
Reference in New Issue
Block a user