Compare commits

...

1 Commit

Author SHA1 Message Date
Pedro Cuenca
38e4f1e014 Manually move the UNet to cuda/cpu. 2023-02-07 10:54:30 +01:00

View File

@@ -193,7 +193,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
device = torch.device(f"cuda:{gpu_id}")
for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
for cpu_offloaded_model in [self.text_encoder, self.vae]:
cpu_offload(cpu_offloaded_model, device)
if self.safety_checker is not None:
@@ -206,9 +206,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
`pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
hooks.
"""
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
if self.device != torch.device("meta") or not hasattr(self.vae, "_hf_hook"):
return self.device
for module in self.unet.modules():
for module in self.vae.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
@@ -600,6 +600,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
offloaded_device = self.unet.device
self.unet.to(self._execution_device)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -628,6 +630,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
self.unet.to(offloaded_device)
if output_type == "latent":
image = latents