Compare commits

...

1 Commits

Author SHA1 Message Date
Patrick von Platen
6b2d9e6acd [Draft] MultiControlNet 2023-03-09 10:46:24 +00:00

View File

@@ -146,6 +146,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
if isinstance(controlnet, list):
controlnet = torch.nn.ModuleList(controlnet)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
@@ -738,7 +741,16 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Denoising loop
# 8. Prepare controlnets for potential multi-controlnet case
controlnets = self.controlnet if isinstance(self.controlnet, torch.nn.ModuleList) else [self.controlnet]
images_per_controlnet = image.shape[0] // len(controlnets)
if images_per_controlnet * len(controlnets) != image.shape[0]:
raise ValueError(f"You have passed {len(controlnets)} ControlNet models, but {image.shape[0]} conditioned images. Please make sure to pass `n` x {len(controlnets)} images to generate `n` output images.")
control_images = image[None, :].reshape(len(controlnets), images_per_controlnet, *image.shape[1:])
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -746,19 +758,23 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
down_block_res_samples, mid_block_res_sample = self.controlnet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
controlnet_cond=image,
return_dict=False,
)
control_down_block_res = control_mid_block_res = 0
for image, controlnet in zip(control_images, controlnets):
down_res, mid_res = self.controlnet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
controlnet_cond=image,
return_dict=False,
)
control_down_block_res += down_res
control_mid_block_res += mid_res
down_block_res_samples = [
control_down_block_res = [
down_block_res_sample * controlnet_conditioning_scale
for down_block_res_sample in down_block_res_samples
for down_block_res_sample in control_down_block_res
]
mid_block_res_sample *= controlnet_conditioning_scale
control_mid_block_res *= controlnet_conditioning_scale
# predict the noise residual
noise_pred = self.unet(
@@ -766,8 +782,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
down_block_additional_residuals=control_down_block_res,
mid_block_additional_residual=control_mid_block_res,
).sample
# perform guidance