Mirror of https://github.com/huggingface/diffusers.git
Synced 2025-12-06 20:44:33 +08:00
Compare commits
2 Commits
torchao-lo... → multi_cont...
| Author | SHA1 | Date |
|---|---|---|
| | c013139aee | |
| | d4aaee48fb | |
@@ -14,14 +14,17 @@
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import PIL.Image
 import torch
+from torch import device, nn
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from ...models.controlnet import ControlNetOutput
+from ...models.modeling_utils import get_parameter_device, get_parameter_dtype
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     PIL_INTERPOLATION,
@@ -85,6 +88,60 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class MultiControlNet(nn.Module):
|
||||
def __init__(self, controlnets: List[ControlNetModel]):
|
||||
super().__init__()
|
||||
self.nets = nn.ModuleList(controlnets)
|
||||
|
||||
@property
|
||||
def device(self) -> device:
|
||||
"""
|
||||
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
||||
device).
|
||||
"""
|
||||
return get_parameter_device(self)
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
"""
|
||||
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
||||
"""
|
||||
return get_parameter_dtype(self)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
controlnet_cond: torch.FloatTensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[ControlNetOutput, Tuple]:
|
||||
num_images_per_net = controlnet_cond.shape[0] // len(self.nets)
|
||||
conds = controlnet_cond[None, :].reshape((num_images_per_net, -1) + controlnet_cond.shape[1:])
|
||||
|
||||
down_block_res_samples, mid_block_res_sample = 0
|
||||
for cond, controlnet in zip(conds, self.nets):
|
||||
down, mid = self.controlnet(
|
||||
sample,
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
cond,
|
||||
class_labels,
|
||||
timestep_cond,
|
||||
attention_mask,
|
||||
cross_attention_kwargs,
|
||||
return_dict,
|
||||
)
|
||||
down_block_res_samples += down
|
||||
mid_block_res_sample += mid
|
||||
|
||||
return down_block_res_samples, mid_block_res_sample
|
||||
|
||||
|
||||
class StableDiffusionControlNetPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
|
||||
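For orientation, a minimal usage sketch of the new wrapper (not part of the diff): the two checkpoint names are only illustrative, and `MultiControlNet` here is the class added above, imported from the patched pipeline module rather than a public diffusers export.

```python
import torch
from diffusers import ControlNetModel

# Illustrative checkpoints; any ControlNetModels with matching dtypes work.
canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
depth = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth")

multi = MultiControlNet([canny, depth])  # class defined in the patch above
print(multi.device, multi.dtype)  # resolved from the wrapped parameters

# The pipeline stacks one conditioning batch per net along dim 0, so a
# forward pass expects controlnet_cond of shape (len(nets) * batch, C, H, W).
```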
@@ -146,6 +203,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )

+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNet(controlnet)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
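With this branch in `__init__`, callers can hand the pipeline a plain list of controlnets and the wrapper is applied transparently. A minimal sketch, reusing the two illustrative controlnets from the sketch above:

```python
from diffusers import StableDiffusionControlNetPipeline

# A list (or tuple) of controlnets is now wrapped in MultiControlNet by __init__.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # illustrative base checkpoint
    controlnet=[canny, depth],
)
```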
@@ -517,7 +577,11 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
         elif isinstance(image[0], torch.Tensor):
             image = torch.cat(image, dim=0)

-        image_batch_size = image.shape[0]
+        num_controlnets = 1 if isinstance(self.controlnet, ControlNetModel) else len(self.controlnet.nets)
+        image_batch_size = image.shape[0] // num_controlnets
+
+        if image_batch_size * num_controlnets != image.shape[0]:
+            raise ValueError(f"The image batch size ({image.shape[0]}) must be divisible by the number of controlnets ({num_controlnets}).")

         if image_batch_size == 1:
             repeat_by = batch_size
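The corrected check above simply verifies that the stacked conditioning batch divides evenly across the controlnets. A self-contained sketch of the arithmetic, with illustrative sizes:

```python
import torch

num_controlnets = 2
image = torch.randn(6, 3, 512, 512)  # two controlnets x three images, stacked on dim 0

image_batch_size = image.shape[0] // num_controlnets  # 3 images per controlnet
assert image_batch_size * num_controlnets == image.shape[0]  # 6 is divisible by 2
```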
@@ -716,7 +780,12 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
         )

         if do_classifier_free_guidance:
-            image = torch.cat([image] * 2)
+            num_control = 1 if isinstance(self.controlnet, ControlNetModel) else len(self.controlnet.nets)
+            image = image[None, :].reshape(num_control, -1, *image.shape[1:])
+
+            # only repeat batch size, but not controlnet dim
+            image = image.repeat(1, 2, 1, 1, 1)
+            image = image.reshape((image.shape[:2].numel(),) + image.shape[2:])

         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
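The reshape/repeat sequence duplicates each controlnet's conditioning block for classifier-free guidance without interleaving blocks across nets. A standalone sketch verifying the shapes, with illustrative sizes:

```python
import torch

num_control, batch = 2, 3
image = torch.randn(num_control * batch, 3, 512, 512)  # per-net blocks stacked on dim 0

image = image[None, :].reshape(num_control, -1, *image.shape[1:])  # (2, 3, 3, 512, 512)
image = image.repeat(1, 2, 1, 1, 1)  # duplicate only the batch dim -> (2, 6, 3, 512, 512)
image = image.reshape((image.shape[:2].numel(),) + image.shape[2:])  # (12, 3, 512, 512)

assert image.shape[0] == 2 * num_control * batch  # one uncond + one cond copy per image
```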