[TPU] add tpu_inference (#27277)

Signed-off-by: Johnny Yang <johnnyyang@google.com>
This commit is contained in:
Johnny Yang
2025-11-26 14:46:36 -08:00
committed by GitHub
parent 56539cddac
commit ba1fcd84a7
4 changed files with 5 additions and 13 deletions

View File

@@ -12,6 +12,4 @@ ray[data]
setuptools==78.1.0
nixl==0.3.0
tpu_info==0.4.0
# Install torch_xla
torch_xla[tpu, pallas]==2.8.0
tpu-inference==0.11.1

View File

@@ -97,11 +97,3 @@ class TpuCommunicator(DeviceCommunicatorBase):
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "TPUs only support dim=-1 for all-gather."
return xm.all_gather(input_, dim=dim)
if USE_TPU_INFERENCE:
from tpu_inference.distributed.device_communicators import (
TpuCommunicator as TpuInferenceCommunicator,
)
TpuCommunicator = TpuInferenceCommunicator # type: ignore

View File

@@ -267,7 +267,9 @@ class TpuPlatform(Platform):
try:
from tpu_inference.platforms import TpuPlatform as TpuInferencePlatform
from tpu_inference.platforms.tpu_platforms import (
TpuPlatform as TpuInferencePlatform,
)
TpuPlatform = TpuInferencePlatform # type: ignore
USE_TPU_INFERENCE = True

View File

@@ -346,6 +346,6 @@ class TPUWorker:
if USE_TPU_INFERENCE:
from tpu_inference.worker import TPUWorker as TpuInferenceWorker
from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker
TPUWorker = TpuInferenceWorker # type: ignore