mirror of https://github.com/vllm-project/vllm.git
[TPU] add tpu_inference (#27277)
Signed-off-by: Johnny Yang <johnnyyang@google.com>
@@ -12,6 +12,4 @@ ray[data]
 setuptools==78.1.0
 nixl==0.3.0
 tpu_info==0.4.0
-
-# Install torch_xla
-torch_xla[tpu, pallas]==2.8.0
+tpu-inference==0.11.1
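The net effect of this hunk is that the TPU requirements now pull in tpu-inference 0.11.1 in place of the directly pinned torch_xla. A minimal sketch (not part of the commit) for checking that an environment actually matches these pins; the package names are copied verbatim from the diff and importlib.metadata is standard library:

from importlib.metadata import PackageNotFoundError, version

# Pins taken directly from the requirements hunk above.
PINS = {
    "setuptools": "78.1.0",
    "nixl": "0.3.0",
    "tpu_info": "0.4.0",
    "tpu-inference": "0.11.1",
}

for name, expected in PINS.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed (expected {expected})")
        continue
    status = "ok" if installed == expected else f"mismatch: expected {expected}"
    print(f"{name}: {installed} ({status})")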
@@ -97,11 +97,3 @@ class TpuCommunicator(DeviceCommunicatorBase):
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         assert dim == -1, "TPUs only support dim=-1 for all-gather."
         return xm.all_gather(input_, dim=dim)
-
-
-if USE_TPU_INFERENCE:
-    from tpu_inference.distributed.device_communicators import (
-        TpuCommunicator as TpuInferenceCommunicator,
-    )
-
-    TpuCommunicator = TpuInferenceCommunicator  # type: ignore
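For context on the lines kept above: gathering along dim=-1 concatenates each rank's tensor on its last axis, which is the only layout xm.all_gather is asserted to support here. A minimal, TPU-free sketch of that shape contract, using torch.cat purely as a stand-in for the collective:

import torch

world_size = 4
# One (2, 3) tensor per simulated rank, filled with the rank id for visibility.
per_rank = [torch.full((2, 3), float(r)) for r in range(world_size)]

# all-gather with dim=-1: every rank ends up with the per-rank tensors
# concatenated on the last axis, so (2, 3) becomes (2, 3 * world_size).
gathered = torch.cat(per_rank, dim=-1)
assert gathered.shape == (2, 3 * world_size)
print(gathered.shape)  # torch.Size([2, 12])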
@@ -267,7 +267,9 @@ class TpuPlatform(Platform):
 
 
 try:
-    from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
+    from tpu_inference.platforms.tpu_platforms import (
+        TpuPlatform as TpuInferencePlatform,
+    )
 
     TpuPlatform = TpuInferencePlatform  # type: ignore
     USE_TPU_INFERENCE = True
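This hunk only changes the import inside the try block; the surrounding pattern is: attempt the external import, rebind TpuPlatform, and flag the swap with USE_TPU_INFERENCE. A self-contained sketch of that pattern is below; the except branch and the local fallback class are assumptions, since only the try body appears in the hunk:

USE_TPU_INFERENCE = False


class TpuPlatform:
    """In-tree fallback, used when tpu_inference cannot be imported (assumed)."""


try:
    # Import path taken from the hunk above.
    from tpu_inference.platforms.tpu_platforms import (
        TpuPlatform as TpuInferencePlatform,
    )

    TpuPlatform = TpuInferencePlatform  # type: ignore
    USE_TPU_INFERENCE = True
except ImportError:
    # Assumed fallback: keep the in-tree TpuPlatform when tpu_inference is absent.
    pass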
@@ -346,6 +346,6 @@ class TPUWorker:
 
 
 if USE_TPU_INFERENCE:
-    from tpu_inference.worker import TPUWorker as TpuInferenceWorker
+    from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker
 
     TPUWorker = TpuInferenceWorker  # type: ignore
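Same idiom as the platform hunk, but gated on the USE_TPU_INFERENCE flag rather than a try/except. A self-contained sketch with hypothetical names (fastlib, FastWidget) standing in for tpu_inference and TpuInferenceWorker; find_spec is used only so the sketch runs on its own, whereas the real flag is set by the platform module:

from importlib.util import find_spec


class Widget:
    """In-tree fallback implementation (hypothetical)."""


# In the commit, USE_TPU_INFERENCE comes from the platform module; probing with
# find_spec is a stand-in that keeps this sketch self-contained.
USE_FASTLIB = find_spec("fastlib") is not None

if USE_FASTLIB:
    from fastlib import FastWidget  # hypothetical optional backend

    # Rebinding the public name means downstream "from ... import Widget"
    # transparently receives the backend class.
    Widget = FastWidget  # type: ignore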