@@ -18,7 +18,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import PIL.Image
 import torch
-import torch.nn.functional as F
 from transformers import (
     CLIPImageProcessor,
     CLIPTextModel,
@@ -35,7 +34,13 @@ from ...loaders import (
     StableDiffusionXLLoraLoaderMixin,
     TextualInversionLoaderMixin,
 )
-from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
+from ...models import (
+    AutoencoderKL,
+    ControlNetUnionModel,
+    ImageProjection,
+    MultiControlNetUnionModel,
+    UNet2DConditionModel,
+)
 from ...models.attention_processor import (
     AttnProcessor2_0,
     XFormersAttnProcessor,
@@ -230,7 +235,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         tokenizer: CLIPTokenizer,
         tokenizer_2: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        controlnet: ControlNetUnionModel,
+        controlnet: Union[
+            ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
+        ],
         scheduler: KarrasDiffusionSchedulers,
         requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
@@ -240,8 +247,8 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
     ):
         super().__init__()

-        if not isinstance(controlnet, ControlNetUnionModel):
-            raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
+        if isinstance(controlnet, (list, tuple)):
+            controlnet = MultiControlNetUnionModel(controlnet)

         self.register_modules(
             vae=vae,
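With this change, `__init__` accepts a list or tuple of `ControlNetUnionModel`s and wraps it in a `MultiControlNetUnionModel`. A minimal construction sketch, assuming this PR is installed (the checkpoint IDs are illustrative; any SDXL base and ControlNetUnion checkpoints should work the same way):

```python
import torch

from diffusers import ControlNetUnionModel, StableDiffusionXLControlNetUnionInpaintPipeline

# Illustrative checkpoints: two union ControlNets, e.g. one per condition group.
controlnet_a = ControlNetUnionModel.from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16)
controlnet_b = ControlNetUnionModel.from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16)

# A plain list is now accepted; the pipeline wraps it in a MultiControlNetUnionModel.
pipe = StableDiffusionXLControlNetUnionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=[controlnet_a, controlnet_b],
    torch_dtype=torch.float16,
)
```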
@@ -660,6 +667,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         controlnet_conditioning_scale=1.0,
         control_guidance_start=0.0,
         control_guidance_end=1.0,
+        control_mode=None,
         callback_on_step_end_tensor_inputs=None,
         padding_mask_crop=None,
     ):
@@ -747,25 +755,34 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
                 "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
             )

-        # Check `image`
-        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
-            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
-        )
-        if (
-            isinstance(self.controlnet, ControlNetModel)
-            or is_compiled
-            and isinstance(self.controlnet._orig_mod, ControlNetModel)
-        ):
-            self.check_image(image, prompt, prompt_embeds)
-        elif (
-            isinstance(self.controlnet, ControlNetUnionModel)
-            or is_compiled
-            and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
-        ):
-            self.check_image(image, prompt, prompt_embeds)
-        else:
-            assert False
+        # `prompt` needs more sophisticated handling when there are multiple
+        # conditionings.
+        if isinstance(self.controlnet, MultiControlNetUnionModel):
+            if isinstance(prompt, list):
+                logger.warning(
+                    f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+                    " prompts. The conditionings will be fixed across the prompts."
+                )
+
+        # Check `image`
+        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+        if isinstance(controlnet, ControlNetUnionModel):
+            for image_ in image:
+                self.check_image(image_, prompt, prompt_embeds)
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            if not isinstance(image, list):
+                raise TypeError("For multiple controlnets: `image` must be type `list`")
+            elif not all(isinstance(i, list) for i in image):
+                raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
+            elif len(image) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+                )
+
+            for images_ in image:
+                for image_ in images_:
+                    self.check_image(image_, prompt, prompt_embeds)

         if not isinstance(control_guidance_start, (tuple, list)):
             control_guidance_start = [control_guidance_start]
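The new `check_image` branches encode a nesting convention: a single `ControlNetUnionModel` takes a flat list of conditioning images, while a `MultiControlNetUnionModel` takes one sub-list per net. A sketch of the two accepted shapes, using placeholder images:

```python
import PIL.Image

# Placeholders standing in for real pose/depth/canny conditioning maps.
pose = PIL.Image.new("RGB", (1024, 1024))
depth = PIL.Image.new("RGB", (1024, 1024))
canny = PIL.Image.new("RGB", (1024, 1024))

# Single ControlNetUnionModel: a flat list, one entry per requested condition.
image_single = [pose, depth]

# MultiControlNetUnionModel with two nets: one sub-list of conditionings per
# net; the outer length must equal len(self.controlnet.nets).
image_multi = [
    [pose],          # conditionings for nets[0]
    [depth, canny],  # conditionings for nets[1]
]
```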
@@ -778,6 +795,12 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
                 f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
             )

+        if isinstance(controlnet, MultiControlNetUnionModel):
+            if len(control_guidance_start) != len(self.controlnet.nets):
+                raise ValueError(
+                    f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+                )
+
         for start, end in zip(control_guidance_start, control_guidance_end):
             if start >= end:
                 raise ValueError(
@@ -788,6 +811,28 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
             if end > 1.0:
                 raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

+        # Check `control_mode`
+        if isinstance(controlnet, ControlNetUnionModel):
+            if max(control_mode) >= controlnet.config.num_control_type:
+                raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
+                if max(_control_mode) >= _controlnet.config.num_control_type:
+                    raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
+
+        # Equal number of `image` and `control_mode` elements
+        if isinstance(controlnet, ControlNetUnionModel):
+            if len(image) != len(control_mode):
+                raise ValueError("Expected len(control_image) == len(control_mode)")
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            if not all(isinstance(i, list) for i in control_mode):
+                raise ValueError(
+                    "For multiple controlnets: elements of control_mode must be lists representing conditioning mode."
+                )
+
+            elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
+                raise ValueError("Expected len(control_image) == len(control_mode)")

         if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
             raise ValueError(
                 "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
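The `control_mode` checks mirror that nesting and bound every mode index by the checkpoint's `num_control_type`. A standalone illustration of the bound (the value 6 and the mode indices are illustrative; consult the checkpoint's model card for the real mode mapping):

```python
# Illustrative: a union checkpoint reporting num_control_type == 6, with
# modes [0, 1] requested for the first net and [3] for the second.
num_control_type = 6
control_mode = [[0, 1], [3]]

for modes in control_mode:
    if max(modes) >= num_control_type:
        raise ValueError(f"control_mode: must be lower than {num_control_type}.")
```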
@@ -1117,7 +1162,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         prompt_2: Optional[Union[str, List[str]]] = None,
         image: PipelineImageInput = None,
         mask_image: PipelineImageInput = None,
-        control_image: PipelineImageInput = None,
+        control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         padding_mask_crop: Optional[int] = None,
@@ -1145,7 +1190,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         guess_mode: bool = False,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
-        control_mode: Optional[Union[int, List[int]]] = None,
+        control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
         guidance_rescale: float = 0.0,
         original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
@@ -1177,6 +1222,13 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
                 repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
                 to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
                 instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
+                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+                specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+                as an image. The dimensions of the output image default to `image`'s dimensions. If height and/or
+                width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+                images must be passed as a list such that each element of the list can be correctly batched for input
+                to a single ControlNet.
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -1269,6 +1321,22 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+                to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+                the corresponding scale as a list.
+            guess_mode (`bool`, *optional*, defaults to `False`):
+                The ControlNet encoder tries to recognize the content of the input image even if you remove all
+                prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+            control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+                The percentage of total steps at which the ControlNet starts applying.
+            control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+                The percentage of total steps at which the ControlNet stops applying.
+            control_mode (`int` or `List[int]` or `List[List[int]]`, *optional*):
+                The control condition types for the ControlNet. See the ControlNet's model card for information on
+                the available control modes. If multiple ControlNets are specified in `init`, `control_mode` should be
+                a list in which each element is the list of control modes for the corresponding ControlNet, reflecting
+                the order of conditions in `control_image`.
             original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                 If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                 `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
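Putting the docstring together, a hedged call sketch (continuing the construction sketch above, so `pipe` holds two ControlNetUnionModels; mode indices and sizes are illustrative):

```python
import PIL.Image

init_image = PIL.Image.new("RGB", (1024, 1024))  # image to inpaint
mask = PIL.Image.new("L", (1024, 1024), 255)     # white = repaint
pose = PIL.Image.new("RGB", (1024, 1024))
depth = PIL.Image.new("RGB", (1024, 1024))

# One sub-list of images and one sub-list of modes per net, in matching order.
result = pipe(
    prompt="a sofa in a bright living room",
    image=init_image,
    mask_image=mask,
    control_image=[[pose], [depth]],
    control_mode=[[0], [1]],
    num_inference_steps=30,
).images[0]
```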
@@ -1333,22 +1401,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(

         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

-        # align format for control guidance
-        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
-            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
-        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
-            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
-
-        # # 0.0 Default height and width to unet
-        # height = height or self.unet.config.sample_size * self.vae_scale_factor
-        # width = width or self.unet.config.sample_size * self.vae_scale_factor
-
-        # 0.1 align format for control guidance
-        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
-            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
-        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
-            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
-
         if not isinstance(control_image, list):
             control_image = [control_image]
         else:
@@ -1357,40 +1409,59 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         if not isinstance(control_mode, list):
             control_mode = [control_mode]

         if len(control_image) != len(control_mode):
             raise ValueError("Expected len(control_image) == len(control_mode)")
+        if isinstance(controlnet, MultiControlNetUnionModel):
+            control_image = [[item] for item in control_image]
+            control_mode = [[item] for item in control_mode]

-        num_control_type = controlnet.config.num_control_type
-
-        # 1. Check inputs
-        control_type = [0 for _ in range(num_control_type)]
-        for _image, control_idx in zip(control_image, control_mode):
-            control_type[control_idx] = 1
-            self.check_inputs(
-                prompt,
-                prompt_2,
-                _image,
-                mask_image,
-                strength,
-                num_inference_steps,
-                callback_steps,
-                output_type,
-                negative_prompt,
-                negative_prompt_2,
-                prompt_embeds,
-                negative_prompt_embeds,
-                ip_adapter_image,
-                ip_adapter_image_embeds,
-                pooled_prompt_embeds,
-                negative_pooled_prompt_embeds,
-                controlnet_conditioning_scale,
-                control_guidance_start,
-                control_guidance_end,
-                callback_on_step_end_tensor_inputs,
-                padding_mask_crop,
-            )
-
-        control_type = torch.Tensor(control_type)
+        # align format for control guidance
+        if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+            control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+            control_guidance_start, control_guidance_end = (
+                mult * [control_guidance_start],
+                mult * [control_guidance_end],
+            )
+
+        if isinstance(controlnet_conditioning_scale, float):
+            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+            controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
+
+        # 1. Check inputs
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            control_image,
+            mask_image,
+            strength,
+            num_inference_steps,
+            callback_steps,
+            output_type,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds,
+            negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+            controlnet_conditioning_scale,
+            control_guidance_start,
+            control_guidance_end,
+            control_mode,
+            callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
+        )
+
+        if isinstance(controlnet, ControlNetUnionModel):
+            control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            control_type = [
+                torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
+                for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
+            ]

         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
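The `scatter_` call builds the one-hot `control_type` vector that tells the union ControlNet which condition slots are active. A standalone illustration (`num_control_type == 6` and the mode indices are illustrative):

```python
import torch

num_control_type = 6
control_mode = [0, 3]  # two active condition slots

control_type = torch.zeros(num_control_type).scatter_(0, torch.tensor(control_mode), 1)
print(control_type)  # tensor([1., 0., 0., 1., 0., 0.])
```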
@@ -1483,21 +1554,55 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
             init_image = init_image.to(dtype=torch.float32)

         # 5.2 Prepare control images
-        for idx, _ in enumerate(control_image):
-            control_image[idx] = self.prepare_control_image(
-                image=control_image[idx],
-                width=width,
-                height=height,
-                batch_size=batch_size * num_images_per_prompt,
-                num_images_per_prompt=num_images_per_prompt,
-                device=device,
-                dtype=controlnet.dtype,
-                crops_coords=crops_coords,
-                resize_mode=resize_mode,
-                do_classifier_free_guidance=self.do_classifier_free_guidance,
-                guess_mode=guess_mode,
-            )
-            height, width = control_image[idx].shape[-2:]
+        if isinstance(controlnet, ControlNetUnionModel):
+            control_images = []
+
+            for image_ in control_image:
+                image_ = self.prepare_control_image(
+                    image=image_,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_images_per_prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    device=device,
+                    dtype=controlnet.dtype,
+                    crops_coords=crops_coords,
+                    resize_mode=resize_mode,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    guess_mode=guess_mode,
+                )
+
+                control_images.append(image_)
+
+            control_image = control_images
+            height, width = control_image[0].shape[-2:]
+
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            control_images = []
+
+            for control_image_ in control_image:
+                images = []
+
+                for image_ in control_image_:
+                    image_ = self.prepare_control_image(
+                        image=image_,
+                        width=width,
+                        height=height,
+                        batch_size=batch_size * num_images_per_prompt,
+                        num_images_per_prompt=num_images_per_prompt,
+                        device=device,
+                        dtype=controlnet.dtype,
+                        crops_coords=crops_coords,
+                        resize_mode=resize_mode,
+                        do_classifier_free_guidance=self.do_classifier_free_guidance,
+                        guess_mode=guess_mode,
+                    )
+
+                    images.append(image_)
+                control_images.append(images)
+
+            control_image = control_images
+            height, width = control_image[0][0].shape[-2:]

         # 5.3 Prepare mask
         mask = self.mask_processor.preprocess(
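For orientation, a mock of the nesting the preparation loops produce for a `MultiControlNetUnionModel`. The shapes assume `prepare_control_image` duplicates the batch for classifier-free guidance, as the other SDXL ControlNet pipelines do; B and N stand for batch size and images per prompt:

```python
import torch

B, N, H, W = 1, 1, 1024, 1024

def prepared() -> torch.Tensor:
    # Stand-in for the output of prepare_control_image under CFG.
    return torch.zeros(2 * B * N, 3, H, W)

control_image = [
    [prepared()],              # nets[0]: one conditioning tensor
    [prepared(), prepared()],  # nets[1]: two conditioning tensors
]
height, width = control_image[0][0].shape[-2:]
print(height, width)  # 1024 1024
```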
@@ -1559,10 +1664,11 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
         # 8.2 Create tensor stating which controlnets to keep
         controlnet_keep = []
         for i in range(len(timesteps)):
-            controlnet_keep.append(
-                1.0
-                - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
-            )
+            keeps = [
+                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                for s, e in zip(control_guidance_start, control_guidance_end)
+            ]
+            controlnet_keep.append(keeps)

         # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         height, width = latents.shape[-2:]
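`controlnet_keep` now stores one keep factor per net per step, so each net can have its own guidance window. A standalone run of the same list comprehension (ten steps, two nets with staggered windows; values are illustrative):

```python
timesteps = list(range(10))  # stand-ins; only len(timesteps) matters here
control_guidance_start = [0.0, 0.5]
control_guidance_end = [0.5, 1.0]

controlnet_keep = []
for i in range(len(timesteps)):
    keeps = [
        1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    controlnet_keep.append(keeps)

print(controlnet_keep[0])   # [1.0, 0.0]: only the first net applies at step 0
print(controlnet_keep[-1])  # [0.0, 1.0]: only the second net at the final step
```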
@@ -1627,11 +1733,24 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
             num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
             timesteps = timesteps[:num_inference_steps]

-        control_type = (
-            control_type.reshape(1, -1)
-            .to(device, dtype=prompt_embeds.dtype)
-            .repeat(batch_size * num_images_per_prompt * 2, 1)
-        )
+        control_type_repeat_factor = (
+            batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
+        )
+
+        if isinstance(controlnet, ControlNetUnionModel):
+            control_type = (
+                control_type.reshape(1, -1)
+                .to(self._execution_device, dtype=prompt_embeds.dtype)
+                .repeat(control_type_repeat_factor, 1)
+            )
+        elif isinstance(controlnet, MultiControlNetUnionModel):
+            control_type = [
+                _control_type.reshape(1, -1)
+                .to(self._execution_device, dtype=prompt_embeds.dtype)
+                .repeat(control_type_repeat_factor, 1)
+                for _control_type in control_type
+            ]

         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
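`control_type_repeat_factor` sizes the control-type batch, doubling it only when classifier-free guidance is on. A standalone illustration (values are illustrative):

```python
import torch

batch_size, num_images_per_prompt = 2, 1
do_classifier_free_guidance = True

control_type = torch.tensor([1.0, 0.0, 0.0, 1.0, 0.0, 0.0])
repeat_factor = batch_size * num_images_per_prompt * (2 if do_classifier_free_guidance else 1)

control_type = control_type.reshape(1, -1).repeat(repeat_factor, 1)
print(control_type.shape)  # torch.Size([4, 6])
```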