mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-11 15:04:45 +08:00
Compare commits
14 Commits
animatedif
...
test-fixes
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2461933857 | ||
|
|
325f6c53ed | ||
|
|
43979c2890 | ||
|
|
9ea6ac1b07 | ||
|
|
2c34c7d6dd | ||
|
|
bffadde126 | ||
|
|
11190ed09a | ||
|
|
35a969d297 | ||
|
|
c5ff469d0e | ||
|
|
bcecfbc873 | ||
|
|
6269045c5b | ||
|
|
6ca9c4af05 | ||
|
|
0532cece97 | ||
|
|
22b45304bf |
@@ -1,6 +1,6 @@
|
||||
diffusers==0.20.1
|
||||
accelerate==0.23.0
|
||||
transformers==4.34.0
|
||||
transformers==4.36.0
|
||||
peft==0.5.0
|
||||
torch==2.0.1
|
||||
torchvision>=0.16
|
||||
|
||||
@@ -22,7 +22,6 @@ import os
|
||||
import random
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
@@ -436,22 +435,6 @@ DATASET_NAME_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
|
||||
"""
|
||||
Returns:
|
||||
a state dict containing just the attention processor parameters.
|
||||
"""
|
||||
attn_processors = unet.attn_processors
|
||||
|
||||
attn_processors_state_dict = {}
|
||||
|
||||
for attn_processor_key, attn_processor in attn_processors.items():
|
||||
for parameter_key, parameter in attn_processor.state_dict().items():
|
||||
attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
|
||||
|
||||
return attn_processors_state_dict
|
||||
|
||||
|
||||
def tokenize_prompt(tokenizer, prompt):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
@@ -640,6 +623,17 @@ def main(args):
|
||||
text_encoder_one.add_adapter(text_lora_config)
|
||||
text_encoder_two.add_adapter(text_lora_config)
|
||||
|
||||
# Make sure the trainable params are in float32.
|
||||
if args.mixed_precision == "fp16":
|
||||
models = [unet]
|
||||
if args.train_text_encoder:
|
||||
models.extend([text_encoder_one, text_encoder_two])
|
||||
for model in models:
|
||||
for param in model.parameters():
|
||||
# only upcast trainable parameters (LoRA) into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
@@ -1187,6 +1181,9 @@ def main(args):
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Final inference
|
||||
# Make sure vae.dtype is consistent with the unet.dtype
|
||||
if args.mixed_precision == "fp16":
|
||||
vae.to(weight_dtype)
|
||||
# Load previous pipeline
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
|
||||
318
src/diffusers/models/downsampling.py
Normal file
318
src/diffusers/models/downsampling.py
Normal file
@@ -0,0 +1,318 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .lora import LoRACompatibleConv
|
||||
from .upsampling import upfirdn2d_native
|
||||
|
||||
|
||||
class Downsample1D(nn.Module):
|
||||
"""A 1D downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
padding (`int`, default `1`):
|
||||
padding for the convolution.
|
||||
name (`str`, default `conv`):
|
||||
name of the downsampling 1D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
padding: int = 1,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
|
||||
if use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
assert inputs.shape[1] == self.channels
|
||||
return self.conv(inputs)
|
||||
|
||||
|
||||
class Downsample2D(nn.Module):
|
||||
"""A 2D downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
padding (`int`, default `1`):
|
||||
padding for the convolution.
|
||||
name (`str`, default `conv`):
|
||||
name of the downsampling 2D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
padding: int = 1,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
if use_conv:
|
||||
conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
self.Conv2d_0 = conv
|
||||
self.conv = conv
|
||||
elif name == "Conv2d_0":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.conv = conv
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.use_conv and self.padding == 0:
|
||||
pad = (0, 1, 0, 1)
|
||||
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if not USE_PEFT_BACKEND:
|
||||
if isinstance(self.conv, LoRACompatibleConv):
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FirDownsample2D(nn.Module):
|
||||
"""A 2D FIR downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
||||
kernel for the FIR filter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
):
|
||||
super().__init__()
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.fir_kernel = fir_kernel
|
||||
self.use_conv = use_conv
|
||||
self.out_channels = out_channels
|
||||
|
||||
def _downsample_2d(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
weight: Optional[torch.FloatTensor] = None,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
||||
arbitrary order.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
weight (`torch.FloatTensor`, *optional*):
|
||||
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
||||
performed by `inChannels = x.shape[0] // numGroups`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to average pooling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer downsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude.
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
|
||||
datatype as `x`.
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
# setup kernel
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * gain
|
||||
|
||||
if self.use_conv:
|
||||
_, _, convH, convW = weight.shape
|
||||
pad_value = (kernel.shape[0] - factor) + (convW - 1)
|
||||
stride_value = [factor, factor]
|
||||
upfirdn_input = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
|
||||
else:
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
down=factor,
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.use_conv:
|
||||
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
|
||||
class KDownsample2D(nn.Module):
|
||||
r"""A 2D K-downsampling layer.
|
||||
|
||||
Parameters:
|
||||
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
|
||||
"""
|
||||
|
||||
def __init__(self, pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
|
||||
weight = inputs.new_zeros(
|
||||
[
|
||||
inputs.shape[1],
|
||||
inputs.shape[1],
|
||||
self.kernel.shape[0],
|
||||
self.kernel.shape[1],
|
||||
]
|
||||
)
|
||||
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv2d(inputs, weight, stride=2)
|
||||
|
||||
|
||||
def downsample_2d(
|
||||
hidden_states: torch.FloatTensor,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
r"""Downsample2D a batch of 2D images with the given filter.
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
||||
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
||||
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
|
||||
shape is a multiple of the downsampling factor.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`)
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to average pooling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer downsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude.
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]`
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * gain
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
kernel.to(device=hidden_states.device),
|
||||
down=factor,
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
return output
|
||||
@@ -23,562 +23,23 @@ import torch.nn.functional as F
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .activations import get_activation
|
||||
from .attention_processor import SpatialNorm
|
||||
from .downsampling import ( # noqa
|
||||
Downsample1D,
|
||||
Downsample2D,
|
||||
FirDownsample2D,
|
||||
KDownsample2D,
|
||||
downsample_2d,
|
||||
)
|
||||
from .lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
from .normalization import AdaGroupNorm
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
"""A 1D upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
use_conv_transpose (`bool`, default `False`):
|
||||
option to use a convolution transpose.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
name (`str`, default `conv`):
|
||||
name of the upsampling 1D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
use_conv_transpose: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
|
||||
self.conv = None
|
||||
if use_conv_transpose:
|
||||
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
assert inputs.shape[1] == self.channels
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(inputs)
|
||||
|
||||
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
||||
|
||||
if self.use_conv:
|
||||
outputs = self.conv(outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Downsample1D(nn.Module):
|
||||
"""A 1D downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
padding (`int`, default `1`):
|
||||
padding for the convolution.
|
||||
name (`str`, default `conv`):
|
||||
name of the downsampling 1D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
padding: int = 1,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
|
||||
if use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
assert inputs.shape[1] == self.channels
|
||||
return self.conv(inputs)
|
||||
|
||||
|
||||
class Upsample2D(nn.Module):
|
||||
"""A 2D upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
use_conv_transpose (`bool`, default `False`):
|
||||
option to use a convolution transpose.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
name (`str`, default `conv`):
|
||||
name of the upsampling 2D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
use_conv_transpose: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
conv = None
|
||||
if use_conv_transpose:
|
||||
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
conv = conv_cls(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.Conv2d_0 = conv
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
output_size: Optional[int] = None,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(hidden_states)
|
||||
|
||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
||||
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
||||
# https://github.com/pytorch/pytorch/issues/86679
|
||||
dtype = hidden_states.dtype
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
||||
if hidden_states.shape[0] >= 64:
|
||||
hidden_states = hidden_states.contiguous()
|
||||
|
||||
# if `output_size` is passed we force the interpolation output
|
||||
# size and do not make use of `scale_factor=2`
|
||||
if output_size is None:
|
||||
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||
else:
|
||||
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
||||
|
||||
# If the input is bfloat16, we cast back to bfloat16
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if self.use_conv:
|
||||
if self.name == "conv":
|
||||
if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
|
||||
hidden_states = self.Conv2d_0(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.Conv2d_0(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Downsample2D(nn.Module):
|
||||
"""A 2D downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
padding (`int`, default `1`):
|
||||
padding for the convolution.
|
||||
name (`str`, default `conv`):
|
||||
name of the downsampling 2D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
padding: int = 1,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
stride = 2
|
||||
self.name = name
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
if use_conv:
|
||||
conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
self.Conv2d_0 = conv
|
||||
self.conv = conv
|
||||
elif name == "Conv2d_0":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.conv = conv
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.use_conv and self.padding == 0:
|
||||
pad = (0, 1, 0, 1)
|
||||
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if not USE_PEFT_BACKEND:
|
||||
if isinstance(self.conv, LoRACompatibleConv):
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FirUpsample2D(nn.Module):
|
||||
"""A 2D FIR upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`, optional):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
||||
kernel for the FIR filter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
):
|
||||
super().__init__()
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.use_conv = use_conv
|
||||
self.fir_kernel = fir_kernel
|
||||
self.out_channels = out_channels
|
||||
|
||||
def _upsample_2d(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
weight: Optional[torch.FloatTensor] = None,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
"""Fused `upsample_2d()` followed by `Conv2d()`.
|
||||
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
||||
arbitrary order.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
weight (`torch.FloatTensor`, *optional*):
|
||||
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
||||
performed by `inChannels = x.shape[0] // numGroups`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to nearest-neighbor upsampling.
|
||||
factor (`int`, *optional*): Integer upsampling factor (default: 2).
|
||||
gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
|
||||
datatype as `hidden_states`.
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
|
||||
# Setup filter kernel.
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
# setup kernel
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * (gain * (factor**2))
|
||||
|
||||
if self.use_conv:
|
||||
convH = weight.shape[2]
|
||||
convW = weight.shape[3]
|
||||
inC = weight.shape[1]
|
||||
|
||||
pad_value = (kernel.shape[0] - factor) - (convW - 1)
|
||||
|
||||
stride = (factor, factor)
|
||||
# Determine data dimensions.
|
||||
output_shape = (
|
||||
(hidden_states.shape[2] - 1) * factor + convH,
|
||||
(hidden_states.shape[3] - 1) * factor + convW,
|
||||
)
|
||||
output_padding = (
|
||||
output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
|
||||
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
|
||||
)
|
||||
assert output_padding[0] >= 0 and output_padding[1] >= 0
|
||||
num_groups = hidden_states.shape[1] // inC
|
||||
|
||||
# Transpose weights.
|
||||
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
|
||||
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
|
||||
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
|
||||
|
||||
inverse_conv = F.conv_transpose2d(
|
||||
hidden_states,
|
||||
weight,
|
||||
stride=stride,
|
||||
output_padding=output_padding,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
output = upfirdn2d_native(
|
||||
inverse_conv,
|
||||
torch.tensor(kernel, device=inverse_conv.device),
|
||||
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
|
||||
)
|
||||
else:
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
up=factor,
|
||||
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.use_conv:
|
||||
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
||||
|
||||
return height
|
||||
|
||||
|
||||
class FirDownsample2D(nn.Module):
|
||||
"""A 2D FIR downsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
||||
kernel for the FIR filter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
):
|
||||
super().__init__()
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.fir_kernel = fir_kernel
|
||||
self.use_conv = use_conv
|
||||
self.out_channels = out_channels
|
||||
|
||||
def _downsample_2d(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
weight: Optional[torch.FloatTensor] = None,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
"""Fused `Conv2d()` followed by `downsample_2d()`.
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
||||
arbitrary order.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
weight (`torch.FloatTensor`, *optional*):
|
||||
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
||||
performed by `inChannels = x.shape[0] // numGroups`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to average pooling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer downsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude.
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
|
||||
datatype as `x`.
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
# setup kernel
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * gain
|
||||
|
||||
if self.use_conv:
|
||||
_, _, convH, convW = weight.shape
|
||||
pad_value = (kernel.shape[0] - factor) + (convW - 1)
|
||||
stride_value = [factor, factor]
|
||||
upfirdn_input = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
|
||||
else:
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
down=factor,
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.use_conv:
|
||||
downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead
|
||||
class KDownsample2D(nn.Module):
|
||||
r"""A 2D K-downsampling layer.
|
||||
|
||||
Parameters:
|
||||
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
|
||||
"""
|
||||
|
||||
def __init__(self, pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
|
||||
weight = inputs.new_zeros(
|
||||
[
|
||||
inputs.shape[1],
|
||||
inputs.shape[1],
|
||||
self.kernel.shape[0],
|
||||
self.kernel.shape[1],
|
||||
]
|
||||
)
|
||||
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv2d(inputs, weight, stride=2)
|
||||
|
||||
|
||||
class KUpsample2D(nn.Module):
|
||||
r"""A 2D K-upsampling layer.
|
||||
|
||||
Parameters:
|
||||
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
|
||||
"""
|
||||
|
||||
def __init__(self, pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
|
||||
weight = inputs.new_zeros(
|
||||
[
|
||||
inputs.shape[1],
|
||||
inputs.shape[1],
|
||||
self.kernel.shape[0],
|
||||
self.kernel.shape[1],
|
||||
]
|
||||
)
|
||||
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
from .upsampling import ( # noqa
|
||||
FirUpsample2D,
|
||||
KUpsample2D,
|
||||
Upsample1D,
|
||||
Upsample2D,
|
||||
upfirdn2d_native,
|
||||
upsample_2d,
|
||||
)
|
||||
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
@@ -894,151 +355,6 @@ class ResidualTemporalBlock1D(nn.Module):
|
||||
return out + self.residual_conv(inputs)
|
||||
|
||||
|
||||
def upsample_2d(
|
||||
hidden_states: torch.FloatTensor,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
r"""Upsample2D a batch of 2D images with the given filter.
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
||||
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
|
||||
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
|
||||
a: multiple of the upsampling factor.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to nearest-neighbor upsampling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer upsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]`
|
||||
"""
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * (gain * (factor**2))
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
kernel.to(device=hidden_states.device),
|
||||
up=factor,
|
||||
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def downsample_2d(
|
||||
hidden_states: torch.FloatTensor,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
r"""Downsample2D a batch of 2D images with the given filter.
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
|
||||
given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
|
||||
specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
|
||||
shape is a multiple of the downsampling factor.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`)
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to average pooling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer downsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude.
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H // factor, W // factor]`
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * gain
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
kernel.to(device=hidden_states.device),
|
||||
down=factor,
|
||||
pad=((pad_value + 1) // 2, pad_value // 2),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def upfirdn2d_native(
|
||||
tensor: torch.Tensor,
|
||||
kernel: torch.Tensor,
|
||||
up: int = 1,
|
||||
down: int = 1,
|
||||
pad: Tuple[int, int] = (0, 0),
|
||||
) -> torch.Tensor:
|
||||
up_x = up_y = up
|
||||
down_x = down_y = down
|
||||
pad_x0 = pad_y0 = pad[0]
|
||||
pad_x1 = pad_y1 = pad[1]
|
||||
|
||||
_, channel, in_h, in_w = tensor.shape
|
||||
tensor = tensor.reshape(-1, in_h, in_w, 1)
|
||||
|
||||
_, in_h, in_w, minor = tensor.shape
|
||||
kernel_h, kernel_w = kernel.shape
|
||||
|
||||
out = tensor.view(-1, in_h, 1, in_w, 1, minor)
|
||||
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
||||
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
||||
|
||||
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
||||
out = out.to(tensor.device) # Move back to mps if necessary
|
||||
out = out[
|
||||
:,
|
||||
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
||||
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
||||
:,
|
||||
]
|
||||
|
||||
out = out.permute(0, 3, 1, 2)
|
||||
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
||||
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
||||
out = F.conv2d(out, w)
|
||||
out = out.reshape(
|
||||
-1,
|
||||
minor,
|
||||
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
||||
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
||||
)
|
||||
out = out.permute(0, 2, 3, 1)
|
||||
out = out[:, ::down_y, ::down_x, :]
|
||||
|
||||
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
||||
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
||||
|
||||
return out.view(-1, channel, out_h, out_w)
|
||||
|
||||
|
||||
class TemporalConvLayer(nn.Module):
|
||||
"""
|
||||
Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from:
|
||||
|
||||
426
src/diffusers/models/upsampling.py
Normal file
426
src/diffusers/models/upsampling.py
Normal file
@@ -0,0 +1,426 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import USE_PEFT_BACKEND
|
||||
from .lora import LoRACompatibleConv
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
"""A 1D upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
use_conv_transpose (`bool`, default `False`):
|
||||
option to use a convolution transpose.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
name (`str`, default `conv`):
|
||||
name of the upsampling 1D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
use_conv_transpose: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
|
||||
self.conv = None
|
||||
if use_conv_transpose:
|
||||
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
assert inputs.shape[1] == self.channels
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(inputs)
|
||||
|
||||
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
|
||||
|
||||
if self.use_conv:
|
||||
outputs = self.conv(outputs)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Upsample2D(nn.Module):
|
||||
"""A 2D upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
use_conv_transpose (`bool`, default `False`):
|
||||
option to use a convolution transpose.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
name (`str`, default `conv`):
|
||||
name of the upsampling 2D layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
use_conv: bool = False,
|
||||
use_conv_transpose: bool = False,
|
||||
out_channels: Optional[int] = None,
|
||||
name: str = "conv",
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
self.name = name
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
|
||||
conv = None
|
||||
if use_conv_transpose:
|
||||
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
|
||||
elif use_conv:
|
||||
conv = conv_cls(self.channels, self.out_channels, 3, padding=1)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if name == "conv":
|
||||
self.conv = conv
|
||||
else:
|
||||
self.Conv2d_0 = conv
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
output_size: Optional[int] = None,
|
||||
scale: float = 1.0,
|
||||
) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
if self.use_conv_transpose:
|
||||
return self.conv(hidden_states)
|
||||
|
||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
||||
# TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
|
||||
# https://github.com/pytorch/pytorch/issues/86679
|
||||
dtype = hidden_states.dtype
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
||||
if hidden_states.shape[0] >= 64:
|
||||
hidden_states = hidden_states.contiguous()
|
||||
|
||||
# if `output_size` is passed we force the interpolation output
|
||||
# size and do not make use of `scale_factor=2`
|
||||
if output_size is None:
|
||||
hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||
else:
|
||||
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
|
||||
|
||||
# If the input is bfloat16, we cast back to bfloat16
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|
||||
if self.use_conv:
|
||||
if self.name == "conv":
|
||||
if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
|
||||
hidden_states = self.conv(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
else:
|
||||
if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
|
||||
hidden_states = self.Conv2d_0(hidden_states, scale)
|
||||
else:
|
||||
hidden_states = self.Conv2d_0(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FirUpsample2D(nn.Module):
|
||||
"""A 2D FIR upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`, optional):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
|
||||
kernel for the FIR filter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
use_conv: bool = False,
|
||||
fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
|
||||
):
|
||||
super().__init__()
|
||||
out_channels = out_channels if out_channels else channels
|
||||
if use_conv:
|
||||
self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.use_conv = use_conv
|
||||
self.fir_kernel = fir_kernel
|
||||
self.out_channels = out_channels
|
||||
|
||||
def _upsample_2d(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
weight: Optional[torch.FloatTensor] = None,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
"""Fused `upsample_2d()` followed by `Conv2d()`.
|
||||
|
||||
Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
|
||||
efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
|
||||
arbitrary order.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
weight (`torch.FloatTensor`, *optional*):
|
||||
Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
|
||||
performed by `inChannels = x.shape[0] // numGroups`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to nearest-neighbor upsampling.
|
||||
factor (`int`, *optional*): Integer upsampling factor (default: 2).
|
||||
gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
|
||||
datatype as `hidden_states`.
|
||||
"""
|
||||
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
|
||||
# Setup filter kernel.
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
# setup kernel
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * (gain * (factor**2))
|
||||
|
||||
if self.use_conv:
|
||||
convH = weight.shape[2]
|
||||
convW = weight.shape[3]
|
||||
inC = weight.shape[1]
|
||||
|
||||
pad_value = (kernel.shape[0] - factor) - (convW - 1)
|
||||
|
||||
stride = (factor, factor)
|
||||
# Determine data dimensions.
|
||||
output_shape = (
|
||||
(hidden_states.shape[2] - 1) * factor + convH,
|
||||
(hidden_states.shape[3] - 1) * factor + convW,
|
||||
)
|
||||
output_padding = (
|
||||
output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
|
||||
output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
|
||||
)
|
||||
assert output_padding[0] >= 0 and output_padding[1] >= 0
|
||||
num_groups = hidden_states.shape[1] // inC
|
||||
|
||||
# Transpose weights.
|
||||
weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
|
||||
weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
|
||||
weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
|
||||
|
||||
inverse_conv = F.conv_transpose2d(
|
||||
hidden_states,
|
||||
weight,
|
||||
stride=stride,
|
||||
output_padding=output_padding,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
output = upfirdn2d_native(
|
||||
inverse_conv,
|
||||
torch.tensor(kernel, device=inverse_conv.device),
|
||||
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
|
||||
)
|
||||
else:
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
torch.tensor(kernel, device=hidden_states.device),
|
||||
up=factor,
|
||||
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.use_conv:
|
||||
height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
|
||||
height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
|
||||
else:
|
||||
height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
|
||||
|
||||
return height
|
||||
|
||||
|
||||
class KUpsample2D(nn.Module):
|
||||
r"""A 2D K-upsampling layer.
|
||||
|
||||
Parameters:
|
||||
pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
|
||||
"""
|
||||
|
||||
def __init__(self, pad_mode: str = "reflect"):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
|
||||
self.pad = kernel_1d.shape[1] // 2 - 1
|
||||
self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
|
||||
|
||||
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
||||
inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
|
||||
weight = inputs.new_zeros(
|
||||
[
|
||||
inputs.shape[1],
|
||||
inputs.shape[1],
|
||||
self.kernel.shape[0],
|
||||
self.kernel.shape[1],
|
||||
]
|
||||
)
|
||||
indices = torch.arange(inputs.shape[1], device=inputs.device)
|
||||
kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
|
||||
weight[indices, indices] = kernel
|
||||
return F.conv_transpose2d(inputs, weight, stride=2, padding=self.pad * 2 + 1)
|
||||
|
||||
|
||||
def upfirdn2d_native(
|
||||
tensor: torch.Tensor,
|
||||
kernel: torch.Tensor,
|
||||
up: int = 1,
|
||||
down: int = 1,
|
||||
pad: Tuple[int, int] = (0, 0),
|
||||
) -> torch.Tensor:
|
||||
up_x = up_y = up
|
||||
down_x = down_y = down
|
||||
pad_x0 = pad_y0 = pad[0]
|
||||
pad_x1 = pad_y1 = pad[1]
|
||||
|
||||
_, channel, in_h, in_w = tensor.shape
|
||||
tensor = tensor.reshape(-1, in_h, in_w, 1)
|
||||
|
||||
_, in_h, in_w, minor = tensor.shape
|
||||
kernel_h, kernel_w = kernel.shape
|
||||
|
||||
out = tensor.view(-1, in_h, 1, in_w, 1, minor)
|
||||
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
|
||||
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
|
||||
|
||||
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
|
||||
out = out.to(tensor.device) # Move back to mps if necessary
|
||||
out = out[
|
||||
:,
|
||||
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
|
||||
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
|
||||
:,
|
||||
]
|
||||
|
||||
out = out.permute(0, 3, 1, 2)
|
||||
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
|
||||
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
|
||||
out = F.conv2d(out, w)
|
||||
out = out.reshape(
|
||||
-1,
|
||||
minor,
|
||||
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
|
||||
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
|
||||
)
|
||||
out = out.permute(0, 2, 3, 1)
|
||||
out = out[:, ::down_y, ::down_x, :]
|
||||
|
||||
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
|
||||
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
|
||||
|
||||
return out.view(-1, channel, out_h, out_w)
|
||||
|
||||
|
||||
def upsample_2d(
|
||||
hidden_states: torch.FloatTensor,
|
||||
kernel: Optional[torch.FloatTensor] = None,
|
||||
factor: int = 2,
|
||||
gain: float = 1,
|
||||
) -> torch.FloatTensor:
|
||||
r"""Upsample2D a batch of 2D images with the given filter.
|
||||
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
|
||||
filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
|
||||
`gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
|
||||
a: multiple of the upsampling factor.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`):
|
||||
Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
|
||||
kernel (`torch.FloatTensor`, *optional*):
|
||||
FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
|
||||
corresponds to nearest-neighbor upsampling.
|
||||
factor (`int`, *optional*, default to `2`):
|
||||
Integer upsampling factor.
|
||||
gain (`float`, *optional*, default to `1.0`):
|
||||
Scaling factor for signal magnitude (default: 1.0).
|
||||
|
||||
Returns:
|
||||
output (`torch.FloatTensor`):
|
||||
Tensor of the shape `[N, C, H * factor, W * factor]`
|
||||
"""
|
||||
assert isinstance(factor, int) and factor >= 1
|
||||
if kernel is None:
|
||||
kernel = [1] * factor
|
||||
|
||||
kernel = torch.tensor(kernel, dtype=torch.float32)
|
||||
if kernel.ndim == 1:
|
||||
kernel = torch.outer(kernel, kernel)
|
||||
kernel /= torch.sum(kernel)
|
||||
|
||||
kernel = kernel * (gain * (factor**2))
|
||||
pad_value = kernel.shape[0] - factor
|
||||
output = upfirdn2d_native(
|
||||
hidden_states,
|
||||
kernel.to(device=hidden_states.device),
|
||||
up=factor,
|
||||
pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
|
||||
)
|
||||
return output
|
||||
@@ -179,12 +179,7 @@ else:
|
||||
_import_structure["stable_diffusion"].extend(
|
||||
[
|
||||
"CLIPImageProjection",
|
||||
"StableDiffusionAttendAndExcitePipeline",
|
||||
"StableDiffusionDepth2ImgPipeline",
|
||||
"StableDiffusionDiffEditPipeline",
|
||||
"StableDiffusionGLIGENPipeline",
|
||||
"StableDiffusionGLIGENPipeline",
|
||||
"StableDiffusionGLIGENTextImagePipeline",
|
||||
"StableDiffusionImageVariationPipeline",
|
||||
"StableDiffusionImg2ImgPipeline",
|
||||
"StableDiffusionInpaintPipeline",
|
||||
@@ -193,13 +188,18 @@ else:
|
||||
"StableDiffusionLDM3DPipeline",
|
||||
"StableDiffusionPanoramaPipeline",
|
||||
"StableDiffusionPipeline",
|
||||
"StableDiffusionSAGPipeline",
|
||||
"StableDiffusionUpscalePipeline",
|
||||
"StableUnCLIPImg2ImgPipeline",
|
||||
"StableUnCLIPPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
|
||||
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
|
||||
_import_structure["stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
|
||||
_import_structure["stable_diffusion_gligen"] = [
|
||||
"StableDiffusionGLIGENPipeline",
|
||||
"StableDiffusionGLIGENTextImagePipeline",
|
||||
]
|
||||
_import_structure["stable_video_diffusion"] = ["StableVideoDiffusionPipeline"]
|
||||
_import_structure["stable_diffusion_xl"].extend(
|
||||
[
|
||||
@@ -209,6 +209,7 @@ else:
|
||||
"StableDiffusionXLPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
|
||||
_import_structure["t2i_adapter"] = [
|
||||
"StableDiffusionAdapterPipeline",
|
||||
"StableDiffusionXLAdapterPipeline",
|
||||
@@ -268,7 +269,7 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
|
||||
else:
|
||||
_import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"])
|
||||
_import_structure["stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"]
|
||||
try:
|
||||
if not is_flax_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -420,11 +421,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
|
||||
from .stable_diffusion import (
|
||||
CLIPImageProjection,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionDiffEditPipeline,
|
||||
StableDiffusionGLIGENPipeline,
|
||||
StableDiffusionGLIGENTextImagePipeline,
|
||||
StableDiffusionImageVariationPipeline,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipeline,
|
||||
@@ -433,12 +430,15 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionLDM3DPipeline,
|
||||
StableDiffusionPanoramaPipeline,
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionSAGPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
from .stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
|
||||
from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
|
||||
from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline
|
||||
from .stable_diffusion_safe import StableDiffusionPipelineSafe
|
||||
from .stable_diffusion_sag import StableDiffusionSAGPipeline
|
||||
from .stable_diffusion_xl import (
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
@@ -498,7 +498,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
|
||||
else:
|
||||
from .stable_diffusion import StableDiffusionKDiffusionPipeline
|
||||
from .stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not is_flax_available():
|
||||
|
||||
@@ -44,7 +44,6 @@ else:
|
||||
_import_structure["pipeline_stable_diffusion_model_editing"] = ["StableDiffusionModelEditingPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_panorama"] = ["StableDiffusionPanoramaPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_paradigms"] = ["StableDiffusionParadigmsPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_upscale"] = ["StableDiffusionUpscalePipeline"]
|
||||
_import_structure["pipeline_stable_unclip"] = ["StableUnCLIPPipeline"]
|
||||
_import_structure["pipeline_stable_unclip_img2img"] = ["StableUnCLIPImg2ImgPipeline"]
|
||||
@@ -67,37 +66,19 @@ try:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionDiffEditPipeline,
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
)
|
||||
|
||||
_dummy_objects.update(
|
||||
{
|
||||
"StableDiffusionDepth2ImgPipeline": StableDiffusionDepth2ImgPipeline,
|
||||
"StableDiffusionDiffEditPipeline": StableDiffusionDiffEditPipeline,
|
||||
"StableDiffusionPix2PixZeroPipeline": StableDiffusionPix2PixZeroPipeline,
|
||||
}
|
||||
)
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_depth2img"] = ["StableDiffusionDepth2ImgPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_pix2pix_zero"] = ["StableDiffusionPix2PixZeroPipeline"]
|
||||
try:
|
||||
if not (
|
||||
is_torch_available()
|
||||
and is_transformers_available()
|
||||
and is_k_diffusion_available()
|
||||
and is_k_diffusion_version(">=", "0.0.12")
|
||||
):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import (
|
||||
dummy_torch_and_transformers_and_k_diffusion_objects,
|
||||
)
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"]
|
||||
try:
|
||||
if not (is_transformers_available() and is_onnx_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -139,13 +120,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionPipelineOutput,
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from .pipeline_stable_diffusion_attend_and_excite import (
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
)
|
||||
from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline
|
||||
from .pipeline_stable_diffusion_gligen_text_image import (
|
||||
StableDiffusionGLIGENTextImagePipeline,
|
||||
)
|
||||
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
|
||||
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
|
||||
from .pipeline_stable_diffusion_instruct_pix2pix import (
|
||||
@@ -156,7 +130,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
from .pipeline_stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline
|
||||
from .pipeline_stable_diffusion_panorama import StableDiffusionPanoramaPipeline
|
||||
from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
|
||||
from .pipeline_stable_diffusion_upscale import StableDiffusionUpscalePipeline
|
||||
from .pipeline_stable_unclip import StableUnCLIPPipeline
|
||||
from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline
|
||||
@@ -181,29 +154,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import (
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
StableDiffusionDiffEditPipeline,
|
||||
StableDiffusionPix2PixZeroPipeline,
|
||||
)
|
||||
else:
|
||||
from .pipeline_stable_diffusion_depth2img import (
|
||||
StableDiffusionDepth2ImgPipeline,
|
||||
)
|
||||
from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
|
||||
|
||||
try:
|
||||
if not (
|
||||
is_torch_available()
|
||||
and is_transformers_available()
|
||||
and is_k_diffusion_available()
|
||||
and is_k_diffusion_version(">=", "0.0.12")
|
||||
):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
|
||||
else:
|
||||
from .pipeline_stable_diffusion_k_diffusion import (
|
||||
StableDiffusionKDiffusionPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_onnx_available()):
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -37,8 +37,8 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -40,8 +40,8 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
50
src/diffusers/pipelines/stable_diffusion_gligen/__init__.py
Normal file
50
src/diffusers/pipelines/stable_diffusion_gligen/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_gligen"] = ["StableDiffusionGLIGENPipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_gligen_text_image"] = ["StableDiffusionGLIGENTextImagePipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline
|
||||
from .pipeline_stable_diffusion_gligen_text_image import StableDiffusionGLIGENTextImagePipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -36,8 +36,8 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -35,9 +35,9 @@ from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .clip_image_project_model import CLIPImageProjection
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.clip_image_project_model import CLIPImageProjection
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -0,0 +1,60 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_k_diffusion_available,
|
||||
is_k_diffusion_version,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (
|
||||
is_transformers_available()
|
||||
and is_torch_available()
|
||||
and is_k_diffusion_available()
|
||||
and is_k_diffusion_version(">=", "0.0.12")
|
||||
):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_k_diffusion"] = ["StableDiffusionKDiffusionPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (
|
||||
is_transformers_available()
|
||||
and is_torch_available()
|
||||
and is_k_diffusion_available()
|
||||
and is_k_diffusion_version(">=", "0.0.12")
|
||||
):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
|
||||
else:
|
||||
from .pipeline_stable_diffusion_k_diffusion import StableDiffusionKDiffusionPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -27,7 +27,7 @@ from ...schedulers import LMSDiscreteScheduler
|
||||
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
48
src/diffusers/pipelines/stable_diffusion_sag/__init__.py
Normal file
48
src/diffusers/pipelines/stable_diffusion_sag/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_stable_diffusion_sag import StableDiffusionSAGPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -34,8 +34,8 @@ from ...utils import (
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
from ..stable_diffusion import StableDiffusionPipelineOutput
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import importlib
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
@@ -24,6 +25,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from packaging import version
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
@@ -1983,10 +1985,26 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
|
||||
fused_te_2_state_dict = pipe.text_encoder_2.state_dict()
|
||||
unet_state_dict = pipe.unet.state_dict()
|
||||
|
||||
peft_ge_070 = version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0")
|
||||
|
||||
def remap_key(key, sd):
|
||||
# some keys have moved around for PEFT >= 0.7.0, but they should still be loaded correctly
|
||||
if (key in sd) or (not peft_ge_070):
|
||||
return key
|
||||
|
||||
# instead of linear.weight, we now have linear.base_layer.weight, etc.
|
||||
if key.endswith(".weight"):
|
||||
key = key[:-7] + ".base_layer.weight"
|
||||
elif key.endswith(".bias"):
|
||||
key = key[:-5] + ".base_layer.bias"
|
||||
return key
|
||||
|
||||
for key, value in text_encoder_1_sd.items():
|
||||
key = remap_key(key, fused_te_state_dict)
|
||||
self.assertTrue(torch.allclose(fused_te_state_dict[key], value))
|
||||
|
||||
for key, value in text_encoder_2_sd.items():
|
||||
key = remap_key(key, fused_te_2_state_dict)
|
||||
self.assertTrue(torch.allclose(fused_te_2_state_dict[key], value))
|
||||
|
||||
for key, value in unet_state_dict.items():
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from diffusers import OnnxStableDiffusionInpaintPipelineLegacy
|
||||
from diffusers.utils.testing_utils import (
|
||||
is_onnx_available,
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
require_onnxruntime,
|
||||
require_torch_gpu,
|
||||
)
|
||||
|
||||
|
||||
if is_onnx_available():
|
||||
import onnxruntime as ort
|
||||
|
||||
|
||||
@nightly
|
||||
@require_onnxruntime
|
||||
@require_torch_gpu
|
||||
class StableDiffusionOnnxInpaintLegacyPipelineIntegrationTests(unittest.TestCase):
|
||||
@property
|
||||
def gpu_provider(self):
|
||||
return (
|
||||
"CUDAExecutionProvider",
|
||||
{
|
||||
"gpu_mem_limit": "15000000000", # 15GB
|
||||
"arena_extend_strategy": "kSameAsRequested",
|
||||
},
|
||||
)
|
||||
|
||||
@property
|
||||
def gpu_options(self):
|
||||
options = ort.SessionOptions()
|
||||
options.enable_mem_pattern = False
|
||||
return options
|
||||
|
||||
def test_inference(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/overture-creations-5sI6fQgYIuo.png"
|
||||
)
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
)
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/in_paint/red_cat_sitting_on_a_park_bench_onnx.npy"
|
||||
)
|
||||
|
||||
# using the PNDM scheduler by default
|
||||
pipe = OnnxStableDiffusionInpaintPipelineLegacy.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
revision="onnx",
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
provider=self.gpu_provider,
|
||||
sess_options=self.gpu_options,
|
||||
)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A red cat sitting on a park bench"
|
||||
|
||||
generator = np.random.RandomState(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
strength=0.75,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=15,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
)
|
||||
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (512, 512, 3)
|
||||
assert np.abs(expected_image - image).max() < 1e-2
|
||||
Reference in New Issue
Block a user