[@cene555][Kandinsky 3.0] Add Kandinsky 3.0 (#5913)

* finalize

* finalize

* finalize

* add slow test

* add slow test

* add slow test

* Fix more

* add slow test

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* Better

* Fix more

* Fix more

* add slow test

* Add auto pipelines

* add slow test

* Add all

* add slow test

* add slow test

* add slow test

* add slow test

* add slow test

* Apply suggestions from code review

* add slow test

* add slow test
This commit is contained in:
Patrick von Platen
2023-11-24 17:46:00 +01:00
committed by GitHub
parent e5f232f76b
commit b978334d71
17 changed files with 2110 additions and 1 deletions

View File

@@ -278,6 +278,8 @@
title: Kandinsky 2.1
- local: api/pipelines/kandinsky_v22
title: Kandinsky 2.2
- local: api/pipelines/kandinsky3
title: Kandinsky 3
- local: api/pipelines/latent_consistency_models
title: Latent Consistency Models
- local: api/pipelines/latent_diffusion

View File

@@ -0,0 +1,24 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Kandinsky 3
TODO
## Kandinsky3Pipeline
[[autodoc]] Kandinsky3Pipeline
- all
- __call__
## Kandinsky3Img2ImgPipeline
[[autodoc]] Kandinsky3Img2ImgPipeline
- all
- __call__

View File

@@ -0,0 +1,98 @@
#!/usr/bin/env python3
import argparse
import fnmatch
from safetensors.torch import load_file
from diffusers import Kandinsky3UNet
MAPPING = {
"to_time_embed.1": "time_embedding.linear_1",
"to_time_embed.3": "time_embedding.linear_2",
"in_layer": "conv_in",
"out_layer.0": "conv_norm_out",
"out_layer.2": "conv_out",
"down_samples": "down_blocks",
"up_samples": "up_blocks",
"projection_lin": "encoder_hid_proj.projection_linear",
"projection_ln": "encoder_hid_proj.projection_norm",
"feature_pooling": "add_time_condition",
"to_query": "to_q",
"to_key": "to_k",
"to_value": "to_v",
"output_layer": "to_out.0",
"self_attention_block": "attentions.0",
}
DYNAMIC_MAP = {
"resnet_attn_blocks.*.0": "resnets_in.*",
"resnet_attn_blocks.*.1": ("attentions.*", 1),
"resnet_attn_blocks.*.2": "resnets_out.*",
}
# MAPPING = {}
def convert_state_dict(unet_state_dict):
"""
Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model.
Args:
unet_model (torch.nn.Module): The original U-Net model.
unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet model to match keys with.
Returns:
OrderedDict: The converted state dictionary.
"""
# Example of renaming logic (this will vary based on your model's architecture)
converted_state_dict = {}
for key in unet_state_dict:
new_key = key
for pattern, new_pattern in MAPPING.items():
new_key = new_key.replace(pattern, new_pattern)
for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items():
has_matched = False
if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*") and not has_matched:
star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1])
if isinstance(dyn_new_pattern, tuple):
new_star = star + dyn_new_pattern[-1]
dyn_new_pattern = dyn_new_pattern[0]
else:
new_star = star
pattern = dyn_pattern.replace("*", str(star))
new_pattern = dyn_new_pattern.replace("*", str(new_star))
new_key = new_key.replace(pattern, new_pattern)
has_matched = True
converted_state_dict[new_key] = unet_state_dict[key]
return converted_state_dict
def main(model_path, output_path):
# Load your original U-Net model
unet_state_dict = load_file(model_path)
# Initialize your Kandinsky3UNet model
config = {}
# Convert the state dict
converted_state_dict = convert_state_dict(unet_state_dict)
unet = Kandinsky3UNet(config)
unet.load_state_dict(converted_state_dict)
unet.save_pretrained(output_path)
print(f"Converted model saved to {output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format")
parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model")
parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
args = parser.parse_args()
main(args.model_path, args.output_path)

View File

@@ -79,6 +79,7 @@ else:
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
"Kandinsky3UNet",
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
@@ -214,6 +215,8 @@ else:
"IFPipeline",
"IFSuperResolutionPipeline",
"ImageTextPipelineOutput",
"Kandinsky3Img2ImgPipeline",
"Kandinsky3Pipeline",
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
"KandinskyImg2ImgPipeline",
@@ -446,6 +449,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
Kandinsky3UNet,
ModelMixin,
MotionAdapter,
MultiAdapter,
@@ -560,6 +564,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
IFPipeline,
IFSuperResolutionPipeline,
ImageTextPipelineOutput,
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
KandinskyImg2ImgPipeline,

View File

@@ -36,6 +36,7 @@ if is_torch_available():
_import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["vq_model"] = ["VQModel"]
@@ -63,6 +64,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel
from .unet_kandi3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .vq_model import VQModel

View File

@@ -16,7 +16,7 @@ from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from torch import einsum, nn
from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available
@@ -2219,6 +2219,44 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
return hidden_states
# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
# this way torch.compile and co. will work as well
class Kandi3AttnProcessor:
r"""
Default kandinsky3 proccesor for performing attention-related computations.
"""
@staticmethod
def _reshape(hid_states, h):
b, n, f = hid_states.shape
d = f // h
return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)
def __call__(
self,
attn,
x,
context,
context_mask=None,
):
query = self._reshape(attn.to_q(x), h=attn.num_heads)
key = self._reshape(attn.to_k(context), h=attn.num_heads)
value = self._reshape(attn.to_v(context), h=attn.num_heads)
attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)
if context_mask is not None:
max_neg_value = -torch.finfo(attention_matrix.dtype).max
context_mask = context_mask.unsqueeze(1).unsqueeze(1)
attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)
out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
out = attn.to_out[0](out)
return out
LORA_ATTENTION_PROCESSORS = (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
@@ -2244,6 +2282,7 @@ CROSS_ATTENTION_PROCESSORS = (
LoRAXFormersAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
Kandi3AttnProcessor,
)
AttentionProcessor = Union[

View File

@@ -0,0 +1,589 @@
import math
from dataclasses import dataclass
from typing import Dict, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from .attention_processor import AttentionProcessor, Kandi3AttnProcessor
from .embeddings import TimestepEmbedding
from .modeling_utils import ModelMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class Kandinsky3UNetOutput(BaseOutput):
sample: torch.FloatTensor = None
# TODO(Yiyi): This class needs to be removed
def set_default_item(condition, item_1, item_2=None):
if condition:
return item_1
else:
return item_2
# TODO(Yiyi): This class needs to be removed
def set_default_layer(condition, layer_1, args_1=[], kwargs_1={}, layer_2=torch.nn.Identity, args_2=[], kwargs_2={}):
if condition:
return layer_1(*args_1, **kwargs_1)
else:
return layer_2(*args_2, **kwargs_2)
# TODO(Yiyi): This class should be removed and be replaced by Timesteps
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x, type_tensor=None):
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
emb = x[:, None] * emb[None, :]
return torch.cat((emb.sin(), emb.cos()), dim=-1)
class Kandinsky3EncoderProj(nn.Module):
def __init__(self, encoder_hid_dim, cross_attention_dim):
super().__init__()
self.projection_linear = nn.Linear(encoder_hid_dim, cross_attention_dim, bias=False)
self.projection_norm = nn.LayerNorm(cross_attention_dim)
def forward(self, x):
x = self.projection_linear(x)
x = self.projection_norm(x)
return x
class Kandinsky3UNet(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
in_channels: int = 4,
time_embedding_dim: int = 1536,
groups: int = 32,
attention_head_dim: int = 64,
layers_per_block: Union[int, Tuple[int]] = 3,
block_out_channels: Tuple[int] = (384, 768, 1536, 3072),
cross_attention_dim: Union[int, Tuple[int]] = 4096,
encoder_hid_dim: int = 4096,
):
super().__init__()
# TOOD(Yiyi): Give better name and put into config for the following 4 parameters
expansion_ratio = 4
compression_ratio = 2
add_cross_attention = (False, True, True, True)
add_self_attention = (False, True, True, True)
out_channels = in_channels
init_channels = block_out_channels[0] // 2
# TODO(Yiyi): Should be replaced with Timesteps class -> make sure that results are the same
# self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1)
self.time_proj = SinusoidalPosEmb(init_channels)
self.time_embedding = TimestepEmbedding(
init_channels,
time_embedding_dim,
)
self.add_time_condition = Kandinsky3AttentionPooling(
time_embedding_dim, cross_attention_dim, attention_head_dim
)
self.conv_in = nn.Conv2d(in_channels, init_channels, kernel_size=3, padding=1)
self.encoder_hid_proj = Kandinsky3EncoderProj(encoder_hid_dim, cross_attention_dim)
hidden_dims = [init_channels] + list(block_out_channels)
in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
text_dims = [set_default_item(is_exist, cross_attention_dim) for is_exist in add_cross_attention]
num_blocks = len(block_out_channels) * [layers_per_block]
layer_params = [num_blocks, text_dims, add_self_attention]
rev_layer_params = map(reversed, layer_params)
cat_dims = []
self.num_levels = len(in_out_dims)
self.down_blocks = nn.ModuleList([])
for level, ((in_dim, out_dim), res_block_num, text_dim, self_attention) in enumerate(
zip(in_out_dims, *layer_params)
):
down_sample = level != (self.num_levels - 1)
cat_dims.append(set_default_item(level != (self.num_levels - 1), out_dim, 0))
self.down_blocks.append(
Kandinsky3DownSampleBlock(
in_dim,
out_dim,
time_embedding_dim,
text_dim,
res_block_num,
groups,
attention_head_dim,
expansion_ratio,
compression_ratio,
down_sample,
self_attention,
)
)
self.up_blocks = nn.ModuleList([])
for level, ((out_dim, in_dim), res_block_num, text_dim, self_attention) in enumerate(
zip(reversed(in_out_dims), *rev_layer_params)
):
up_sample = level != 0
self.up_blocks.append(
Kandinsky3UpSampleBlock(
in_dim,
cat_dims.pop(),
out_dim,
time_embedding_dim,
text_dim,
res_block_num,
groups,
attention_head_dim,
expansion_ratio,
compression_ratio,
up_sample,
self_attention,
)
)
self.conv_norm_out = nn.GroupNorm(groups, init_channels)
self.conv_act_out = nn.SiLU()
self.conv_out = nn.Conv2d(init_channels, out_channels, kernel_size=3, padding=1)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "set_processor"):
processors[f"{name}.processor"] = module.processor
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.set_attn_processor(Kandi3AttnProcessor())
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
# TODO(Yiyi): Clean up the following variables - these names should not be used
# but instead only the ones that we pass to forward
x = sample
context_mask = encoder_attention_mask
context = encoder_hidden_states
if not torch.is_tensor(timestep):
dtype = torch.float32 if isinstance(timestep, float) else torch.int32
timestep = torch.tensor([timestep], dtype=dtype, device=sample.device)
elif len(timestep.shape) == 0:
timestep = timestep[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = timestep.expand(sample.shape[0])
time_embed_input = self.time_proj(timestep).to(x.dtype)
time_embed = self.time_embedding(time_embed_input)
context = self.encoder_hid_proj(context)
if context is not None:
time_embed = self.add_time_condition(time_embed, context, context_mask)
hidden_states = []
x = self.conv_in(x)
for level, down_sample in enumerate(self.down_blocks):
x = down_sample(x, time_embed, context, context_mask)
if level != self.num_levels - 1:
hidden_states.append(x)
for level, up_sample in enumerate(self.up_blocks):
if level != 0:
x = torch.cat([x, hidden_states.pop()], dim=1)
x = up_sample(x, time_embed, context, context_mask)
x = self.conv_norm_out(x)
x = self.conv_act_out(x)
x = self.conv_out(x)
if not return_dict:
return (x,)
return Kandinsky3UNetOutput(sample=x)
class Kandinsky3UpSampleBlock(nn.Module):
def __init__(
self,
in_channels,
cat_dim,
out_channels,
time_embed_dim,
context_dim=None,
num_blocks=3,
groups=32,
head_dim=64,
expansion_ratio=4,
compression_ratio=2,
up_sample=True,
self_attention=True,
):
super().__init__()
up_resolutions = [[None, set_default_item(up_sample, True), None, None]] + [[None] * 4] * (num_blocks - 1)
hidden_channels = (
[(in_channels + cat_dim, in_channels)]
+ [(in_channels, in_channels)] * (num_blocks - 2)
+ [(in_channels, out_channels)]
)
attentions = []
resnets_in = []
resnets_out = []
self.self_attention = self_attention
self.context_dim = context_dim
attentions.append(
set_default_layer(
self_attention,
Kandinsky3AttentionBlock,
(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
layer_2=nn.Identity,
)
)
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
resnets_in.append(
Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution)
)
attentions.append(
set_default_layer(
context_dim is not None,
Kandinsky3AttentionBlock,
(in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
layer_2=nn.Identity,
)
)
resnets_out.append(
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
)
self.attentions = nn.ModuleList(attentions)
self.resnets_in = nn.ModuleList(resnets_in)
self.resnets_out = nn.ModuleList(resnets_out)
def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out):
x = resnet_in(x, time_embed)
if self.context_dim is not None:
x = attention(x, time_embed, context, context_mask, image_mask)
x = resnet_out(x, time_embed)
if self.self_attention:
x = self.attentions[0](x, time_embed, image_mask=image_mask)
return x
class Kandinsky3DownSampleBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
time_embed_dim,
context_dim=None,
num_blocks=3,
groups=32,
head_dim=64,
expansion_ratio=4,
compression_ratio=2,
down_sample=True,
self_attention=True,
):
super().__init__()
attentions = []
resnets_in = []
resnets_out = []
self.self_attention = self_attention
self.context_dim = context_dim
attentions.append(
set_default_layer(
self_attention,
Kandinsky3AttentionBlock,
(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
layer_2=nn.Identity,
)
)
up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, set_default_item(down_sample, False), None]]
hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
resnets_in.append(
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
)
attentions.append(
set_default_layer(
context_dim is not None,
Kandinsky3AttentionBlock,
(out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio),
layer_2=nn.Identity,
)
)
resnets_out.append(
Kandinsky3ResNetBlock(
out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets_in = nn.ModuleList(resnets_in)
self.resnets_out = nn.ModuleList(resnets_out)
def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
if self.self_attention:
x = self.attentions[0](x, time_embed, image_mask=image_mask)
for attention, resnet_in, resnet_out in zip(self.attentions[1:], self.resnets_in, self.resnets_out):
x = resnet_in(x, time_embed)
if self.context_dim is not None:
x = attention(x, time_embed, context, context_mask, image_mask)
x = resnet_out(x, time_embed)
return x
class Kandinsky3ConditionalGroupNorm(nn.Module):
def __init__(self, groups, normalized_shape, context_dim):
super().__init__()
self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)
self.context_mlp = nn.Sequential(nn.SiLU(), nn.Linear(context_dim, 2 * normalized_shape))
self.context_mlp[1].weight.data.zero_()
self.context_mlp[1].bias.data.zero_()
def forward(self, x, context):
context = self.context_mlp(context)
for _ in range(len(x.shape[2:])):
context = context.unsqueeze(-1)
scale, shift = context.chunk(2, dim=1)
x = self.norm(x) * (scale + 1.0) + shift
return x
# TODO(Yiyi): This class should ideally not even exist, it slows everything needlessly down. I'm pretty
# sure we can delete it and instead just pass an attention_mask
class Attention(nn.Module):
def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
super().__init__()
assert out_channels % head_dim == 0
self.num_heads = out_channels // head_dim
self.scale = head_dim**-0.5
# to_q
self.to_q = nn.Linear(in_channels, out_channels, bias=False)
# to_k
self.to_k = nn.Linear(context_dim, out_channels, bias=False)
# to_v
self.to_v = nn.Linear(context_dim, out_channels, bias=False)
processor = Kandi3AttnProcessor()
self.set_processor(processor)
# to_out
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(out_channels, out_channels, bias=False))
def set_processor(self, processor: "AttnProcessor"): # noqa: F821
# if current processor is in `self._modules` and if passed `processor` is not, we need to
# pop `processor` from `self._modules`
if (
hasattr(self, "processor")
and isinstance(self.processor, torch.nn.Module)
and not isinstance(processor, torch.nn.Module)
):
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
self._modules.pop("processor")
self.processor = processor
def forward(self, x, context, context_mask=None, image_mask=None):
return self.processor(
self,
x,
context=context,
context_mask=context_mask,
)
class Kandinsky3Block(nn.Module):
def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
super().__init__()
self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
self.activation = nn.SiLU()
self.up_sample = set_default_layer(
up_resolution is not None and up_resolution,
nn.ConvTranspose2d,
(in_channels, in_channels),
{"kernel_size": 2, "stride": 2},
)
padding = int(kernel_size > 1)
self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
self.down_sample = set_default_layer(
up_resolution is not None and not up_resolution,
nn.Conv2d,
(out_channels, out_channels),
{"kernel_size": 2, "stride": 2},
)
def forward(self, x, time_embed):
x = self.group_norm(x, time_embed)
x = self.activation(x)
x = self.up_sample(x)
x = self.projection(x)
x = self.down_sample(x)
return x
class Kandinsky3ResNetBlock(nn.Module):
def __init__(
self, in_channels, out_channels, time_embed_dim, norm_groups=32, compression_ratio=2, up_resolutions=4 * [None]
):
super().__init__()
kernel_sizes = [1, 3, 3, 1]
hidden_channel = max(in_channels, out_channels) // compression_ratio
hidden_channels = (
[(in_channels, hidden_channel)] + [(hidden_channel, hidden_channel)] * 2 + [(hidden_channel, out_channels)]
)
self.resnet_blocks = nn.ModuleList(
[
Kandinsky3Block(in_channel, out_channel, time_embed_dim, kernel_size, norm_groups, up_resolution)
for (in_channel, out_channel), kernel_size, up_resolution in zip(
hidden_channels, kernel_sizes, up_resolutions
)
]
)
self.shortcut_up_sample = set_default_layer(
True in up_resolutions, nn.ConvTranspose2d, (in_channels, in_channels), {"kernel_size": 2, "stride": 2}
)
self.shortcut_projection = set_default_layer(
in_channels != out_channels, nn.Conv2d, (in_channels, out_channels), {"kernel_size": 1}
)
self.shortcut_down_sample = set_default_layer(
False in up_resolutions, nn.Conv2d, (out_channels, out_channels), {"kernel_size": 2, "stride": 2}
)
def forward(self, x, time_embed):
out = x
for resnet_block in self.resnet_blocks:
out = resnet_block(out, time_embed)
x = self.shortcut_up_sample(x)
x = self.shortcut_projection(x)
x = self.shortcut_down_sample(x)
x = x + out
return x
class Kandinsky3AttentionPooling(nn.Module):
def __init__(self, num_channels, context_dim, head_dim=64):
super().__init__()
self.attention = Attention(context_dim, num_channels, context_dim, head_dim)
def forward(self, x, context, context_mask=None):
context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask)
return x + context.squeeze(1)
class Kandinsky3AttentionBlock(nn.Module):
def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
super().__init__()
self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
self.attention = Attention(num_channels, num_channels, context_dim or num_channels, head_dim)
hidden_channels = expansion_ratio * num_channels
self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
self.feed_forward = nn.Sequential(
nn.Conv2d(num_channels, hidden_channels, kernel_size=1, bias=False),
nn.SiLU(),
nn.Conv2d(hidden_channels, num_channels, kernel_size=1, bias=False),
)
def forward(self, x, time_embed, context=None, context_mask=None, image_mask=None):
height, width = x.shape[-2:]
out = self.in_norm(x, time_embed)
out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1)
context = context if context is not None else out
if image_mask is not None:
mask_height, mask_width = image_mask.shape[-2:]
kernel_size = (mask_height // height, mask_width // width)
image_mask = F.max_pool2d(image_mask, kernel_size, kernel_size)
image_mask = image_mask.reshape(image_mask.shape[0], -1)
out = self.attention(out, context, context_mask, image_mask)
out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width)
x = x + out
out = self.out_norm(x, time_embed)
out = self.feed_forward(out)
x = x + out
return x

View File

@@ -110,6 +110,7 @@ else:
"KandinskyV22PriorEmb2EmbPipeline",
"KandinskyV22PriorPipeline",
]
_import_structure["kandinsky3"] = ["Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline"]
_import_structure["latent_consistency_models"] = [
"LatentConsistencyModelImg2ImgPipeline",
"LatentConsistencyModelPipeline",
@@ -338,6 +339,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
KandinskyV22PriorEmb2EmbPipeline,
KandinskyV22PriorPipeline,
)
from .kandinsky3 import (
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
)
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .latent_diffusion import LDMTextToImagePipeline
from .musicldm import MusicLDMPipeline

View File

@@ -42,6 +42,7 @@ from .kandinsky2_2 import (
KandinskyV22InpaintPipeline,
KandinskyV22Pipeline,
)
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .pixart_alpha import PixArtAlphaPipeline
from .stable_diffusion import (
@@ -64,6 +65,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("if", IFPipeline),
("kandinsky", KandinskyCombinedPipeline),
("kandinsky22", KandinskyV22CombinedPipeline),
("kandinsky3", Kandinsky3Pipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
("wuerstchen", WuerstchenCombinedPipeline),
@@ -79,6 +81,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("if", IFImg2ImgPipeline),
("kandinsky", KandinskyImg2ImgCombinedPipeline),
("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
("kandinsky3", Kandinsky3Img2ImgPipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
("lcm", LatentConsistencyModelImg2ImgPipeline),

View File

@@ -0,0 +1,49 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["kandinsky3_pipeline"] = ["Kandinsky3Pipeline"]
_import_structure["kandinsky3img2img_pipeline"] = ["Kandinsky3Img2ImgPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .kandinsky3_pipeline import Kandinsky3Pipeline
from .kandinsky3img2img_pipeline import Kandinsky3Img2ImgPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,452 @@
from typing import Callable, List, Optional, Union
import torch
from transformers import T5EncoderModel, T5Tokenizer
from ...loaders import LoraLoaderMixin
from ...models import Kandinsky3UNet, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
logging,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2
if height % scale_factor**2 != 0:
new_height += 1
new_width = width // scale_factor**2
if width % scale_factor**2 != 0:
new_width += 1
return new_height * scale_factor, new_width * scale_factor
class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
model_cpu_offload_seq = "text_encoder->unet->movq"
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
unet: Kandinsky3UNet,
scheduler: DDPMScheduler,
movq: VQModel,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq
)
def remove_all_hooks(self):
if is_accelerate_available():
from accelerate.hooks import remove_hook_from_module
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
for model in [self.text_encoder, self.unet]:
if model is not None:
remove_hook_from_module(model, recurse=True)
self.unet_offload_hook = None
self.text_encoder_offload_hook = None
self.final_offload_hook = None
def process_embeds(self, embeddings, attention_mask, cut_context):
if cut_context:
embeddings[attention_mask == 0] = torch.zeros_like(embeddings[attention_mask == 0])
max_seq_length = attention_mask.sum(-1).max() + 1
embeddings = embeddings[:, :max_seq_length]
attention_mask = attention_mask[:, :max_seq_length]
return embeddings, attention_mask
@torch.no_grad()
def encode_prompt(
self,
prompt,
do_classifier_free_guidance=True,
num_images_per_prompt=1,
device=None,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
_cut_context=False,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
if prompt is not None and negative_prompt is not None:
if type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
if device is None:
device = self._execution_device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
max_length = 128
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds = self.text_encoder(
text_input_ids,
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds, attention_mask = self.process_embeds(prompt_embeds, attention_mask, _cut_context)
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(2)
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
attention_mask = attention_mask.repeat(num_images_per_prompt, 1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
if negative_prompt is not None:
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=128,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids = uncond_input.input_ids.to(device)
negative_attention_mask = uncond_input.attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
text_input_ids,
attention_mask=negative_attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds[:, : prompt_embeds.shape[1]]
negative_attention_mask = negative_attention_mask[:, : prompt_embeds.shape[1]]
negative_prompt_embeds = negative_prompt_embeds * negative_attention_mask.unsqueeze(2)
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_attention_mask = torch.zeros_like(attention_mask)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
if negative_prompt_embeds.shape != prompt_embeds.shape:
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_attention_mask = negative_attention_mask.repeat(num_images_per_prompt, 1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
else:
negative_prompt_embeds = None
negative_attention_mask = None
return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
latents = latents * scheduler.init_noise_sigma
return latents
def check_inputs(
self,
prompt,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
num_inference_steps: int = 100,
guidance_scale: float = 3.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
height: Optional[int] = 1024,
width: Optional[int] = 1024,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
latents=None,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
timesteps are used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 3.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
"""
cut_context = True
device = self._execution_device
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt(
prompt,
do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt,
device=device,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
_cut_context=cut_context,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool()
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latents
height, width = downscale_height_and_width(height, width, 8)
latents = self.prepare_latents(
(batch_size * num_images_per_prompt, 4, height, width),
prompt_embeds.dtype,
device,
generator,
latents,
self.scheduler,
)
if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
self.text_encoder_offload_hook.offload()
# 7. Denoising loop
# TODO(Yiyi): Correct the following line and use correctly
# num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=attention_mask,
return_dict=False,
)[0]
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond
# noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred,
t,
latents,
generator=generator,
).prev_sample
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
if output_type not in ["pt", "np", "pil"]:
raise ValueError(
f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}"
)
if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)

View File

@@ -0,0 +1,460 @@
import inspect
from typing import Callable, List, Optional, Union
import numpy as np
import PIL
import PIL.Image
import torch
from transformers import T5EncoderModel, T5Tokenizer
from ...loaders import LoraLoaderMixin
from ...models import Kandinsky3UNet, VQModel
from ...schedulers import DDPMScheduler
from ...utils import (
is_accelerate_available,
logging,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2
if height % scale_factor**2 != 0:
new_height += 1
new_width = width // scale_factor**2
if width % scale_factor**2 != 0:
new_width += 1
return new_height * scale_factor, new_width * scale_factor
def prepare_image(pil_image):
arr = np.array(pil_image.convert("RGB"))
arr = arr.astype(np.float32) / 127.5 - 1
arr = np.transpose(arr, [2, 0, 1])
image = torch.from_numpy(arr).unsqueeze(0)
return image
class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
model_cpu_offload_seq = "text_encoder->unet->movq"
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
unet: Kandinsky3UNet,
scheduler: DDPMScheduler,
movq: VQModel,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer, text_encoder=text_encoder, unet=unet, scheduler=scheduler, movq=movq
)
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start:]
return timesteps, num_inference_steps - t_start
def remove_all_hooks(self):
if is_accelerate_available():
from accelerate.hooks import remove_hook_from_module
else:
raise ImportError("Please install accelerate via `pip install accelerate`")
for model in [self.text_encoder, self.unet]:
if model is not None:
remove_hook_from_module(model, recurse=True)
self.unet_offload_hook = None
self.text_encoder_offload_hook = None
self.final_offload_hook = None
def _process_embeds(self, embeddings, attention_mask, cut_context):
# return embeddings, attention_mask
if cut_context:
embeddings[attention_mask == 0] = torch.zeros_like(embeddings[attention_mask == 0])
max_seq_length = attention_mask.sum(-1).max() + 1
embeddings = embeddings[:, :max_seq_length]
attention_mask = attention_mask[:, :max_seq_length]
return embeddings, attention_mask
@torch.no_grad()
def encode_prompt(
self,
prompt,
do_classifier_free_guidance=True,
num_images_per_prompt=1,
device=None,
negative_prompt=None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
_cut_context=False,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
"""
if prompt is not None and negative_prompt is not None:
if type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
if device is None:
device = self._execution_device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
max_length = 128
if prompt_embeds is None:
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds = self.text_encoder(
text_input_ids,
attention_mask=attention_mask,
)
prompt_embeds = prompt_embeds[0]
prompt_embeds, attention_mask = self._process_embeds(prompt_embeds, attention_mask, _cut_context)
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(2)
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
attention_mask = attention_mask.repeat(num_images_per_prompt, 1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
if negative_prompt is not None:
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=128,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids = uncond_input.input_ids.to(device)
negative_attention_mask = uncond_input.attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
text_input_ids,
attention_mask=negative_attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds[:, : prompt_embeds.shape[1]]
negative_attention_mask = negative_attention_mask[:, : prompt_embeds.shape[1]]
negative_prompt_embeds = negative_prompt_embeds * negative_attention_mask.unsqueeze(2)
else:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_attention_mask = torch.zeros_like(attention_mask)
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
if negative_prompt_embeds.shape != prompt_embeds.shape:
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_attention_mask = negative_attention_mask.repeat(num_images_per_prompt, 1)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
else:
negative_prompt_embeds = None
negative_attention_mask = None
return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
if image.shape[1] == 4:
init_latents = image
else:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
elif isinstance(generator, list):
init_latents = [
self.movq.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = self.movq.encode(image).latent_dist.sample(generator)
init_latents = self.movq.config.scaling_factor * init_latents
init_latents = torch.cat([init_latents], dim=0)
shape = init_latents.shape
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
latents = init_latents
return latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
):
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
strength: float = 0.3,
num_inference_steps: int = 25,
guidance_scale: float = 3.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
latents=None,
):
cut_context = True
# 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt(
prompt,
do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt,
device=device,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
_cut_context=cut_context,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool()
if not isinstance(image, list):
image = [image]
if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image):
raise ValueError(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor"
)
image = torch.cat([prepare_image(i) for i in image], dim=0)
image = image.to(dtype=prompt_embeds.dtype, device=device)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
# 5. Prepare latents
latents = self.movq.encode(image)["latents"]
latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
latents = self.prepare_latents(
latents, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
)
if hasattr(self, "text_encoder_offload_hook") and self.text_encoder_offload_hook is not None:
self.text_encoder_offload_hook.offload()
# 7. Denoising loop
# TODO(Yiyi): Correct the following line and use correctly
# num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=attention_mask,
)[0]
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred,
t,
latents,
generator=generator,
).prev_sample
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
if output_type not in ["pt", "np", "pil"]:
raise ValueError(
f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}"
)
if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)

View File

@@ -77,6 +77,21 @@ class ControlNetModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class Kandinsky3UNet(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ModelMixin(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -242,6 +242,36 @@ class ImageTextPipelineOutput(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Kandinsky3Img2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class Kandinsky3Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class KandinskyCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,98 @@
#!/usr/bin/env python3
import argparse
import fnmatch
from safetensors.torch import load_file
from diffusers import Kandinsky3UNet
MAPPING = {
"to_time_embed.1": "time_embedding.linear_1",
"to_time_embed.3": "time_embedding.linear_2",
"in_layer": "conv_in",
"out_layer.0": "conv_norm_out",
"out_layer.2": "conv_out",
"down_samples": "down_blocks",
"up_samples": "up_blocks",
"projection_lin": "encoder_hid_proj.projection_linear",
"projection_ln": "encoder_hid_proj.projection_norm",
"feature_pooling": "add_time_condition",
"to_query": "to_q",
"to_key": "to_k",
"to_value": "to_v",
"output_layer": "to_out.0",
"self_attention_block": "attentions.0",
}
DYNAMIC_MAP = {
"resnet_attn_blocks.*.0": "resnets_in.*",
"resnet_attn_blocks.*.1": ("attentions.*", 1),
"resnet_attn_blocks.*.2": "resnets_out.*",
}
# MAPPING = {}
def convert_state_dict(unet_state_dict):
"""
Convert the state dict of a U-Net model to match the key format expected by Kandinsky3UNet model.
Args:
unet_model (torch.nn.Module): The original U-Net model.
unet_kandi3_model (torch.nn.Module): The Kandinsky3UNet model to match keys with.
Returns:
OrderedDict: The converted state dictionary.
"""
# Example of renaming logic (this will vary based on your model's architecture)
converted_state_dict = {}
for key in unet_state_dict:
new_key = key
for pattern, new_pattern in MAPPING.items():
new_key = new_key.replace(pattern, new_pattern)
for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items():
has_matched = False
if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*") and not has_matched:
star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1])
if isinstance(dyn_new_pattern, tuple):
new_star = star + dyn_new_pattern[-1]
dyn_new_pattern = dyn_new_pattern[0]
else:
new_star = star
pattern = dyn_pattern.replace("*", str(star))
new_pattern = dyn_new_pattern.replace("*", str(new_star))
new_key = new_key.replace(pattern, new_pattern)
has_matched = True
converted_state_dict[new_key] = unet_state_dict[key]
return converted_state_dict
def main(model_path, output_path):
# Load your original U-Net model
unet_state_dict = load_file(model_path)
# Initialize your Kandinsky3UNet model
config = {}
# Convert the state dict
converted_state_dict = convert_state_dict(unet_state_dict)
unet = Kandinsky3UNet(config)
unet.load_state_dict(converted_state_dict)
unet.save_pretrained(output_path)
print(f"Converted model saved to {output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format")
parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model")
parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
args = parser.parse_args()
main(args.model_path, args.output_path)

View File

View File

@@ -0,0 +1,237 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
AutoPipelineForImage2Image,
AutoPipelineForText2Image,
Kandinsky3Pipeline,
Kandinsky3UNet,
VQModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.utils.testing_utils import (
enable_full_determinism,
load_image,
require_torch_gpu,
slow,
)
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class Kandinsky3PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = Kandinsky3Pipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
test_xformers_attention = False
@property
def dummy_movq_kwargs(self):
return {
"block_out_channels": [32, 64],
"down_block_types": ["DownEncoderBlock2D", "AttnDownEncoderBlock2D"],
"in_channels": 3,
"latent_channels": 4,
"layers_per_block": 1,
"norm_num_groups": 8,
"norm_type": "spatial",
"num_vq_embeddings": 12,
"out_channels": 3,
"up_block_types": [
"AttnUpDecoderBlock2D",
"UpDecoderBlock2D",
],
"vq_embed_dim": 4,
}
@property
def dummy_movq(self):
torch.manual_seed(0)
model = VQModel(**self.dummy_movq_kwargs)
return model
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = Kandinsky3UNet(
in_channels=4,
time_embedding_dim=4,
groups=2,
attention_head_dim=4,
layers_per_block=3,
block_out_channels=(32, 64),
cross_attention_dim=4,
encoder_hid_dim=32,
)
scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
steps_offset=1,
beta_schedule="squaredcos_cap_v2",
clip_sample=True,
thresholding=False,
)
torch.manual_seed(0)
movq = self.dummy_movq
torch.manual_seed(0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"unet": unet,
"scheduler": scheduler,
"movq": movq,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "np",
"width": 16,
"height": 16,
}
return inputs
def test_kandinsky3(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
output = pipe(**self.get_dummy_inputs(device))
image = output.images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 16, 16, 3)
expected_slice = np.array([0.3768, 0.4373, 0.4865, 0.4890, 0.4299, 0.5122, 0.4921, 0.4924, 0.5599])
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=1e-2)
def test_model_cpu_offload_forward_pass(self):
# TODO(Yiyi) - this test should work, skipped for time reasons for now
pass
@slow
@require_torch_gpu
class Kandinsky3PipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_kandinskyV3(self):
pipe = AutoPipelineForText2Image.from_pretrained(
"kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
generator = torch.Generator(device="cpu").manual_seed(0)
image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]
assert image.size == (1024, 1024)
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png"
)
image_processor = VaeImageProcessor()
image_np = image_processor.pil_to_numpy(image)
expected_image_np = image_processor.pil_to_numpy(expected_image)
self.assertTrue(np.allclose(image_np, expected_image_np, atol=5e-2))
def test_kandinskyV3_img2img(self):
pipe = AutoPipelineForImage2Image.from_pretrained(
"kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png"
)
w, h = 512, 512
image = image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
prompt = "A painting of the inside of a subway train with tiny raccoons."
image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0]
assert image.size == (512, 512)
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/i2i.png"
)
image_processor = VaeImageProcessor()
image_np = image_processor.pil_to_numpy(image)
expected_image_np = image_processor.pil_to_numpy(expected_image)
self.assertTrue(np.allclose(image_np, expected_image_np, atol=5e-2))