Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-10 22:44:38 +08:00)
Compare commits (1 commit): animatedif ... multi_cont
| Author | SHA1 | Date |
|---|---|---|
|  | 6b2d9e6acd |  |
```diff
@@ -146,6 +146,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
+        if isinstance(controlnet, list):
+            controlnet = torch.nn.ModuleList(controlnet)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
```
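For context, a minimal usage sketch of what this change enables at construction time, assuming the patch is applied. The checkpoint ids are just examples, and passing a plain Python list as `controlnet` is the behaviour being added here, not an existing documented API:

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# Two independently trained ControlNets (example checkpoints).
controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
controlnet_pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)

# With this patch, a list passed as `controlnet` is wrapped into a torch.nn.ModuleList in __init__.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=[controlnet_canny, controlnet_pose],
    torch_dtype=torch.float16,
)
```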
```diff
@@ -738,7 +741,16 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 8. Denoising loop
+        # 8. Prepare controlnets for potential multi-controlnet case
+        controlnets = self.controlnet if isinstance(self.controlnet, torch.nn.ModuleList) else [self.controlnet]
+        images_per_controlnet = image.shape[0] // len(controlnets)
+
+        if images_per_controlnet * len(controlnets) != image.shape[0]:
+            raise ValueError(f"You have passed {len(controlnets)} ControlNet models, but {image.shape[0]} conditioned images. Please make sure to pass `n` x {len(controlnets)} images to generate `n` output images.")
+
+        control_images = image[None, :].reshape(len(controlnets), images_per_controlnet, *image.shape[1:])
+
+        # 9. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
```
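The new bookkeeping above splits the stacked conditioning images evenly across the ControlNets. A self-contained sketch with made-up shapes, outside the pipeline, to illustrate the arithmetic (`image[None, :].reshape(...)` in the diff is equivalent to the direct `reshape` used here):

```python
import torch

num_controlnets = 2
n = 3  # desired number of output images

# The caller is expected to stack n * num_controlnets conditioning images in the batch dimension.
image = torch.randn(n * num_controlnets, 3, 512, 512)

images_per_controlnet = image.shape[0] // num_controlnets  # -> 3
assert images_per_controlnet * num_controlnets == image.shape[0]

# Group the batch into one chunk of n images per ControlNet.
control_images = image.reshape(num_controlnets, images_per_controlnet, *image.shape[1:])
print(control_images.shape)  # torch.Size([2, 3, 3, 512, 512])
```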
```diff
@@ -746,19 +758,23 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
-                down_block_res_samples, mid_block_res_sample = self.controlnet(
-                    latent_model_input,
-                    t,
-                    encoder_hidden_states=prompt_embeds,
-                    controlnet_cond=image,
-                    return_dict=False,
-                )
+                control_down_block_res = control_mid_block_res = 0
+                for image, controlnet in zip(control_images, controlnets):
+                    down_res, mid_res = controlnet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        controlnet_cond=image,
+                        return_dict=False,
+                    )
+                    control_down_block_res += down_res
+                    control_mid_block_res += mid_res
 
-                down_block_res_samples = [
+                control_down_block_res = [
                     down_block_res_sample * controlnet_conditioning_scale
-                    for down_block_res_sample in down_block_res_samples
+                    for down_block_res_sample in control_down_block_res
                 ]
-                mid_block_res_sample *= controlnet_conditioning_scale
+                control_mid_block_res *= controlnet_conditioning_scale
 
                 # predict the noise residual
                 noise_pred = self.unet(
```
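The loop above sums the residuals from each ControlNet before handing them to the UNet. With `return_dict=False`, a ControlNet forward pass returns a tuple of down-block tensors plus one mid-block tensor, so the down-block part has to be accumulated element by element rather than with a single `+=` on an integer accumulator. A minimal sketch of that accumulation pattern with dummy tensors (the helper and shapes are invented for illustration):

```python
import torch

def dummy_controlnet_forward():
    # Stand-in for one ControlNet call with return_dict=False:
    # a tuple of down-block residuals and a single mid-block residual.
    down = tuple(torch.randn(2, 320, 64, 64) for _ in range(3))
    mid = torch.randn(2, 1280, 8, 8)
    return down, mid

control_down_block_res, control_mid_block_res = None, 0.0
for _ in range(2):  # e.g. two ControlNets
    down_res, mid_res = dummy_controlnet_forward()
    if control_down_block_res is None:
        control_down_block_res = list(down_res)
    else:
        control_down_block_res = [prev + cur for prev, cur in zip(control_down_block_res, down_res)]
    control_mid_block_res = control_mid_block_res + mid_res

print(len(control_down_block_res), control_mid_block_res.shape)
```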
```diff
@@ -766,8 +782,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
                     t,
                     encoder_hidden_states=prompt_embeds,
                     cross_attention_kwargs=cross_attention_kwargs,
-                    down_block_additional_residuals=down_block_res_samples,
-                    mid_block_additional_residual=mid_block_res_sample,
+                    down_block_additional_residuals=control_down_block_res,
+                    mid_block_additional_residual=control_mid_block_res,
                 ).sample
 
                 # perform guidance
```
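Finally, a hedged end-to-end sketch of how a call could look once the pipeline holds two ControlNets (continuing the construction sketch after the first hunk). The assumption that a list of conditioning images is stacked along the batch dimension, one per ControlNet for a single output image, follows the bookkeeping introduced above; the placeholder images and file name are illustrative:

```python
from PIL import Image

# Placeholder conditioning maps; in practice a Canny edge map and an OpenPose rendering.
canny_image = Image.new("RGB", (512, 512))
pose_image = Image.new("RGB", (512, 512))

# n = 1 output image x 2 ControlNets -> 2 conditioning images in the batch.
output = pipe(
    "a futuristic city at dusk",
    image=[canny_image, pose_image],
    num_inference_steps=30,
)
output.images[0].save("multi_controlnet.png")
```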