mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-12 23:44:30 +08:00
Compare commits
2 Commits
modular-cu
...
multi_cont
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c013139aee | ||
|
|
d4aaee48fb |
@@ -14,14 +14,17 @@
|
|||||||
|
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import torch
|
import torch
|
||||||
|
from torch import device, nn
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||||
|
from ...models.controlnet import ControlNetOutput
|
||||||
|
from ...models.modeling_utils import get_parameter_device, get_parameter_dtype
|
||||||
from ...schedulers import KarrasDiffusionSchedulers
|
from ...schedulers import KarrasDiffusionSchedulers
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
PIL_INTERPOLATION,
|
PIL_INTERPOLATION,
|
||||||
@@ -85,6 +88,60 @@ EXAMPLE_DOC_STRING = """
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class MultiControlNet(nn.Module):
    """Run several ControlNets on the same latents and sum their residuals.

    Each wrapped net receives the same `sample`, `timestep` and
    `encoder_hidden_states` but its own slice of `controlnet_cond`. The
    per-net outputs are added together so the result can be consumed exactly
    like the output of a single ControlNet.

    Args:
        controlnets: The ControlNet modules to wrap, in the order in which
            their conditioning images are stacked in `controlnet_cond`.
    """

    def __init__(self, controlnets: List["ControlNetModel"]):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
        """
        return get_parameter_dtype(self)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.FloatTensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union["ControlNetOutput", Tuple]:
        """Apply every wrapped net and return the summed residuals.

        Args:
            sample: Latents forwarded unchanged to every net.
            timestep: Timestep forwarded unchanged to every net.
            encoder_hidden_states: Text conditioning forwarded unchanged to
                every net.
            controlnet_cond: Conditioning for all nets stacked along dim 0 and
                grouped by net: the first `shape[0] // len(self.nets)` rows
                belong to the first net, the next group to the second, etc.
            class_labels: Forwarded unchanged to every net.
            timestep_cond: Forwarded unchanged to every net.
            attention_mask: Forwarded unchanged to every net.
            cross_attention_kwargs: Forwarded unchanged to every net.
            return_dict: If `False`, return a plain
                `(down_block_res_samples, mid_block_res_sample)` tuple;
                otherwise wrap the same values in a `ControlNetOutput`.

        Returns:
            The element-wise sum over all nets of the down-block residuals
            (as a tuple of tensors) and of the mid-block residual.

        Raises:
            ValueError: If no nets are wrapped, or if the leading dimension of
                `controlnet_cond` is not divisible by the number of nets.
        """
        num_nets = len(self.nets)
        if num_nets == 0:
            raise ValueError("`MultiControlNet` needs at least one ControlNet to run.")

        total_images = controlnet_cond.shape[0]
        if total_images % num_nets != 0:
            raise ValueError(
                f"`controlnet_cond` has {total_images} images along dim 0, which cannot be split evenly"
                f" across {num_nets} ControlNets."
            )

        # Leading axis must index the nets so that zip() pairs each net with
        # its own group of conditioning images.
        images_per_net = total_images // num_nets
        conds = controlnet_cond.reshape((num_nets, images_per_net) + tuple(controlnet_cond.shape[1:]))

        down_block_res_samples = None
        mid_block_res_sample = None
        for cond, controlnet in zip(conds, self.nets):
            # Always request the tuple form: a dict-like output cannot be
            # unpacked into (down, mid). Keywords keep the call independent of
            # the positional order of the wrapped net's signature.
            down, mid = controlnet(
                sample=sample,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=cond,
                class_labels=class_labels,
                timestep_cond=timestep_cond,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )

            if down_block_res_samples is None:
                down_block_res_samples = tuple(down)
                mid_block_res_sample = mid
            else:
                # Sum residuals block by block; `+=` on the containers would
                # concatenate instead of adding tensors.
                down_block_res_samples = tuple(
                    accumulated + current for accumulated, current in zip(down_block_res_samples, down)
                )
                mid_block_res_sample = mid_block_res_sample + mid

        if not return_dict:
            return down_block_res_samples, mid_block_res_sample

        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
        )
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusionControlNetPipeline(DiffusionPipeline):
|
class StableDiffusionControlNetPipeline(DiffusionPipeline):
|
||||||
r"""
|
r"""
|
||||||
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
|
Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
|
||||||
@@ -146,6 +203,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
|
|||||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(controlnet, (list, tuple)):
|
||||||
|
controlnet = MultiControlNet(controlnet)
|
||||||
|
|
||||||
self.register_modules(
|
self.register_modules(
|
||||||
vae=vae,
|
vae=vae,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
@@ -517,7 +577,11 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
|
|||||||
elif isinstance(image[0], torch.Tensor):
|
elif isinstance(image[0], torch.Tensor):
|
||||||
image = torch.cat(image, dim=0)
|
image = torch.cat(image, dim=0)
|
||||||
|
|
||||||
image_batch_size = image.shape[0]
|
num_controlnets = 1 if isinstance(self.controlnet, ControlNetModel) else len(self.controlnet.nets)
|
||||||
|
image_batch_size = image.shape[0] // num_controlnets
|
||||||
|
|
||||||
|
if image_batch_size != image.shape[0] * num_controlnets:
|
||||||
|
raise ValueError("TODO: Good error message here")
|
||||||
|
|
||||||
if image_batch_size == 1:
|
if image_batch_size == 1:
|
||||||
repeat_by = batch_size
|
repeat_by = batch_size
|
||||||
@@ -716,7 +780,12 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if do_classifier_free_guidance:
|
if do_classifier_free_guidance:
|
||||||
image = torch.cat([image] * 2)
|
num_control = 1 if isinstance(self.controlnet, ControlNetModel) else len(self.controlnet.nets)
|
||||||
|
image = image[None, :].reshape(num_control, -1, *image.shape[1:])
|
||||||
|
|
||||||
|
# only repeat batch size, but not controlnet dim
|
||||||
|
image = image.repeat(1, 2, 1, 1, 1)
|
||||||
|
image = image.reshape((image.shape[:2].numel(),) + image.shape[2:])
|
||||||
|
|
||||||
# 5. Prepare timesteps
|
# 5. Prepare timesteps
|
||||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||||
|
|||||||
Reference in New Issue
Block a user