[BugFix] Work around graph partition x torch.compile cache issue (#26956)
Signed-off-by: Richard Zou <zou3519@gmail.com>
@@ -337,9 +337,8 @@ def run_model(llama_config, compile_config: CompilationConfig) -> torch.Tensor:
 def test_toy_llama(
     backend: str, use_inductor_graph_partition: bool, monkeypatch, tmp_path
 ):
-    # We disable the vLLM compile cache into a new tmp dir for 2 reasons:
+    # We disable the vLLM compile cache into a new tmp dir for 1 reason:
     # 1. To make sure we can properly track the number of Inductor compilations.
-    # 2. Inductor partitioning does not play nicely with Autograd cache (below)
     monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")

     if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
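As context for the test setup above, here is a minimal, self-contained pytest sketch of the `monkeypatch.setenv` pattern the test relies on to disable the vLLM compile cache; `count_inductor_compilations_enabled` is a hypothetical stand-in, not the helper actually used by `test_toy_llama`.

```python
import os

def count_inductor_compilations_enabled() -> bool:
    # Hypothetical stand-in: the real test counts Inductor compilations.
    # With the cache disabled, every configuration triggers a fresh compile,
    # so those counts are reliable.
    return os.environ.get("VLLM_DISABLE_COMPILE_CACHE") == "1"

def test_compile_cache_disabled(monkeypatch):
    # monkeypatch.setenv scopes the variable to this test, so the toggle
    # cannot leak into other test cases in the same session.
    monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
    assert count_inductor_compilations_enabled()
```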
@@ -369,15 +368,6 @@ def test_toy_llama(
         cudagraph_capture_sizes=[1, 2],
     )

-    # FIXME(luka/boyuan): the graph from the previous test case
-    # (no inductor partition) gets cached by AotAutograd so then the
-    # compilation with inductor partitioning incorrectly loads an unpartitioned
-    # graph and never partitions. I think this is a bug with custom inductor
-    # partitioning but does not affect vLLM more generally as vLLM uses its own
-    # cache (which takes inductor partitioning into account).
-    if use_inductor_graph_partition:
-        compile_config_no_split.inductor_compile_config["force_disable_caches"] = True
-
     compile_config_split = deepcopy(compile_config_no_split)
     compile_config_split.splitting_ops = ["silly::attention"]
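The removed FIXME describes a cache-key collision: the unpartitioned graph compiled by the previous case is keyed the same as the partitioned one, so the second compile silently reuses the wrong artifact. Below is a minimal, self-contained sketch of that failure mode; `fake_cache`, `cache_key`, and `compile_fn` are illustrative names, not vLLM or torch.compile APIs.

```python
import hashlib

fake_cache = {}  # stands in for the on-disk torch.compile / AOTAutograd cache

def cache_key(graph_src, splitting_ops):
    # Buggy key: it ignores splitting_ops, mirroring the PyTorch 2.9 issue
    # where Inductor graph partition is not part of the cache key.
    return hashlib.sha256(graph_src.encode()).hexdigest()

def compile_fn(graph_src, splitting_ops):
    key = cache_key(graph_src, splitting_ops)
    if key in fake_cache:
        return fake_cache[key]  # cache hit, possibly the wrong artifact
    artifact = (graph_src, tuple(splitting_ops or ()))
    fake_cache[key] = artifact
    return artifact

# The first compile (no partitioning) populates the cache; the second compile
# asks for partitioning but gets the unpartitioned artifact back, as the FIXME noted.
no_split = compile_fn("toy_llama_graph", splitting_ops=None)
split = compile_fn("toy_llama_graph", splitting_ops=["silly::attention"])
assert no_split == split  # the collision the force_disable_caches hack avoided
```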
@@ -110,6 +110,27 @@ class PostGradPassManager(CustomGraphPass):
         self.post_cleanup = PostCleanupPass(config)
         self.fix_functionalization = FixFunctionalizationPass(config)

+        # [HACK: Bug with Inductor graph partition and torch.compile cache]
+        # In PyTorch 2.9, torch.compile has a bug where the graph
+        # partition is not taken into account during caching.
+        # Because vLLM's Mode.VLLM_COMPILE is the only mode that uses
+        # Inductor graph partition, and VLLM_COMPILE implies there
+        # is a PostGradPassManager, we put the list of operators to graph
+        # partition into the PostGradPassManager's uuid (which
+        # then gets incorporated into Inductor's FX graph cache key).
+        # Remove this hack whenever torch.compile fixes it.
+
+        # This is the list of operators that vLLM asks Inductor to split.
+        self.inductor_splitting_ops = []
+        if (
+            config.compilation_config.use_inductor_graph_partition
+            and config.compilation_config.splitting_ops is not None
+        ):
+            # Sort them so we're not dependent on the ordering.
+            self.inductor_splitting_ops = sorted(
+                config.compilation_config.splitting_ops
+            )
+
     def add(self, pass_: InductorPass):
         assert isinstance(pass_, InductorPass)
         self.passes.append(pass_)
@@ -120,8 +141,16 @@ class PostGradPassManager(CustomGraphPass):
         affects compilation caching. Its uuid depends on the UUIDs of all
         dependent passes and the pass config. See InductorPass for more info.
         """
-        state = {"pass_config": self.pass_config.uuid(), "passes": []}
+        state = {
+            "pass_config": self.pass_config.uuid(),
+            "passes": [],
+            "inductor_splitting_ops": [],
+        }
         for pass_ in self.passes:
             state["passes"].append(pass_.uuid())
         state["passes"].append(self.fix_functionalization.uuid())
+
+        # See [HACK: Bug with Inductor graph partition and torch.compile cache]
+        state["inductor_splitting_ops"].extend(self.inductor_splitting_ops)
+
         return InductorPass.hash_dict(state)
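As a rough illustration of the workaround (not the actual vLLM implementation), the sketch below folds the sorted splitting ops into a hash the same way the `uuid()` change above does, so partitioned and unpartitioned compilations end up with distinct cache keys; `pass_manager_uuid` is a hypothetical helper.

```python
import hashlib
import json

def pass_manager_uuid(pass_uuids, splitting_ops):
    # Mirror of the uuid() change above: the sorted splitting ops become part
    # of the hashed state, which Inductor folds into its FX graph cache key.
    state = {
        "passes": list(pass_uuids),
        # sorted() keeps the uuid independent of the order splitting_ops is given in
        "inductor_splitting_ops": sorted(splitting_ops or []),
    }
    return hashlib.sha256(json.dumps(state, sort_keys=True).encode()).hexdigest()

no_split = pass_manager_uuid(["noop_pass"], None)
split = pass_manager_uuid(["noop_pass"], ["silly::attention"])
assert no_split != split  # distinct keys, so no stale unpartitioned graph is reused
```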