[ptxla] fix pytorch xla inference on TPUs. (#13463)

Co-authored-by: Juan Acevedo <jfacevedo@google.com>
This commit is contained in:
Juan Acevedo
2026-04-13 20:51:26 -07:00
committed by GitHub
parent 5063aa5566
commit 26bb7fa0cb

View File

@@ -877,10 +877,7 @@ class FluxPipeline(
             self.scheduler.config.get("max_shift", 1.15),
         )
-        if XLA_AVAILABLE:
-            timestep_device = "cpu"
-        else:
-            timestep_device = device
+        timestep_device = device
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
             num_inference_steps,