Compare commits

...

2 Commits

Author               SHA1         Message          Date
Patrick von Platen   c013139aee   [Draft] Upload   2023-03-09 11:32:01 +00:00
Patrick von Platen   d4aaee48fb   add draft        2023-03-09 11:25:58 +00:00


@@ -14,14 +14,17 @@
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import PIL.Image
 import torch
+from torch import device, nn
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models.controlnet import ControlNetOutput
+from ...models.modeling_utils import get_parameter_device, get_parameter_dtype
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     PIL_INTERPOLATION,
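For reference, `get_parameter_device` and `get_parameter_dtype` simply report where a module's weights live. A minimal sketch of the idea (a simplification, not the actual diffusers implementation, which also falls back to buffers):

import torch
from torch import nn

def get_parameter_device(module: nn.Module) -> torch.device:
    # device of the first parameter found (assumes all parameters share one device)
    return next(module.parameters()).device

def get_parameter_dtype(module: nn.Module) -> torch.dtype:
    # dtype of the first parameter found (assumes a uniform dtype)
    return next(module.parameters()).dtype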
@@ -85,6 +88,60 @@ EXAMPLE_DOC_STRING = """
"""
class MultiControlNet(nn.Module):
def __init__(self, controlnets: List[ControlNetModel]):
super().__init__()
self.nets = nn.ModuleList(controlnets)
@property
def device(self) -> device:
"""
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
device).
"""
return get_parameter_device(self)
@property
def dtype(self) -> torch.dtype:
"""
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
"""
return get_parameter_dtype(self)
+    def forward(
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        controlnet_cond: torch.FloatTensor,
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[ControlNetOutput, Tuple]:
+        # conditioning images are stacked along dim 0, one contiguous group per controlnet
+        conds = controlnet_cond.reshape((len(self.nets), -1) + controlnet_cond.shape[1:])
+
+        down_block_res_samples, mid_block_res_sample = None, None
+        for cond, controlnet in zip(conds, self.nets):
+            down, mid = controlnet(
+                sample,
+                timestep,
+                encoder_hidden_states,
+                cond,
+                class_labels,
+                timestep_cond,
+                attention_mask,
+                cross_attention_kwargs,
+                return_dict=False,
+            )
+            # sum the residuals contributed by each controlnet
+            if down_block_res_samples is None:
+                down_block_res_samples, mid_block_res_sample = list(down), mid
+            else:
+                down_block_res_samples = [d_prev + d for d_prev, d in zip(down_block_res_samples, down)]
+                mid_block_res_sample = mid_block_res_sample + mid
+
+        return down_block_res_samples, mid_block_res_sample
+
+
 class StableDiffusionControlNetPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
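As a rough usage sketch of the new wrapper (the checkpoints and shapes below are illustrative, not part of this diff), conditioning images are stacked along dim 0, one contiguous group per controlnet:

import torch
from diffusers import ControlNetModel

# illustrative checkpoints; any two ControlNetModel instances work
canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose")
multi = MultiControlNet([canny, pose])

sample = torch.randn(2, 4, 64, 64)    # batch of 2 latents
text = torch.randn(2, 77, 768)        # toy text embeddings
cond = torch.randn(4, 3, 512, 512)    # [canny_0, canny_1, pose_0, pose_1]

down, mid = multi(sample, 10, text, cond)  # residuals summed over both nets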
@@ -146,6 +203,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
if isinstance(controlnet, (list, tuple)):
controlnet = MultiControlNet(controlnet)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
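With the branch above in place, passing a list of controlnets when building the pipeline should wrap them automatically. A hedged sketch (model ids are illustrative):

from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnets = [
    ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny"),
    ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose"),
]
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnets
)
# pipe.controlnet is now a MultiControlNet wrapping both models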
@@ -517,7 +577,11 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
         elif isinstance(image[0], torch.Tensor):
             image = torch.cat(image, dim=0)

-        image_batch_size = image.shape[0]
+        num_controlnets = 1 if isinstance(self.controlnet, ControlNetModel) else len(self.controlnet.nets)
+        if image.shape[0] % num_controlnets != 0:
+            raise ValueError(f"Image batch size {image.shape[0]} must be divisible by the number of controlnets {num_controlnets}.")
+        image_batch_size = image.shape[0] // num_controlnets

         if image_batch_size == 1:
             repeat_by = batch_size
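To make the check concrete, a tiny worked example (toy shapes, not from the diff):

import torch

num_controlnets = 2
image = torch.randn(6, 3, 512, 512)  # 3 conditioning images per controlnet, stacked on dim 0
assert image.shape[0] % num_controlnets == 0
image_batch_size = image.shape[0] // num_controlnets  # 3 per controlnet
# a stack of 5 images would fail the check: 5 % 2 != 0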
@@ -716,7 +780,12 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
         )

         if do_classifier_free_guidance:
-            image = torch.cat([image] * 2)
+            num_control = 1 if isinstance(self.controlnet, ControlNetModel) else len(self.controlnet.nets)
+            image = image[None, :].reshape(num_control, -1, *image.shape[1:])
+            # repeat only the per-net batch dim for classifier-free guidance, not the controlnet dim
+            image = image.repeat(1, 2, 1, 1, 1)
+            image = image.reshape((image.shape[:2].numel(),) + image.shape[2:])

         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
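The reshape/repeat dance above can be sanity-checked in isolation. A standalone sketch with toy values (not part of the pipeline code):

import torch

num_control, batch = 2, 3
image = torch.arange(num_control * batch).float().reshape(-1, 1, 1, 1)  # (6, 1, 1, 1)

x = image[None, :].reshape(num_control, -1, *image.shape[1:])  # (2, 3, 1, 1, 1)
x = x.repeat(1, 2, 1, 1, 1)                                    # duplicate each net's batch
x = x.reshape((x.shape[:2].numel(),) + x.shape[2:])            # (12, 1, 1, 1)

# per-net groups are duplicated in place: [0 1 2 0 1 2 | 3 4 5 3 4 5],
# i.e. the same result as torch.cat([group] * 2) within each group
print(x.flatten().tolist())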