Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 12:34:13 +08:00)
Asymmetric vqgan (#3956)
* added AsymmetricAutoencoderKL
* fixed copies+dummy
* added script to convert original asymmetric vqgan
* added docs
* updated docs
* fixed style
* fixes, added tests
* update doc
* fixed doc
* fixed tests
* naming (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* naming (Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>)
* updated code example
* updated doc
* comments fixes
* added docstring (Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>)
* comments fixes
* added inpaint pipeline tests
* comment suggestion: delete method
* yet another fixes

Co-authored-by: Ruslan Vorovchenko <r.vorovchenko@prequelapp.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Committed via GitHub · parent 2551b73670 · commit 07f1fbb18e
docs/source/en/_toctree.yml
@@ -166,6 +166,8 @@
       title: VQModel
     - local: api/models/autoencoderkl
       title: AutoencoderKL
+    - local: api/models/asymmetricautoencoderkl
+      title: AsymmetricAutoencoderKL
     - local: api/models/transformer2d
       title: Transformer2D
     - local: api/models/transformer_temporal
docs/source/en/api/models/asymmetricautoencoderkl.mdx (new file, 55 lines)
@@ -0,0 +1,55 @@
# AsymmetricAutoencoderKL

An improved, larger variational autoencoder (VAE) model with KL loss for the inpainting task: [Designing a Better Asymmetric VQGAN for StableDiffusion](https://arxiv.org/abs/2306.04632) by Zixin Zhu, Xuelu Feng, Dongdong Chen, Jianmin Bao, Le Wang, Yinpeng Chen, Lu Yuan, Gang Hua.

The abstract from the paper is:

*StableDiffusion is a revolutionary text-to-image generator that is causing a stir in the world of image generation and editing. Unlike traditional methods that learn a diffusion model in pixel space, StableDiffusion learns a diffusion model in the latent space via a VQGAN, ensuring both efficiency and quality. It not only supports image generation tasks, but also enables image editing for real images, such as image inpainting and local editing. However, we have observed that the vanilla VQGAN used in StableDiffusion leads to significant information loss, causing distortion artifacts even in non-edited image regions. To this end, we propose a new asymmetric VQGAN with two simple designs. Firstly, in addition to the input from the encoder, the decoder contains a conditional branch that incorporates information from task-specific priors, such as the unmasked image region in inpainting. Secondly, the decoder is much heavier than the encoder, allowing for more detailed recovery while only slightly increasing the total inference cost. The training cost of our asymmetric VQGAN is cheap, and we only need to retrain a new asymmetric decoder while keeping the vanilla VQGAN encoder and StableDiffusion unchanged. Our asymmetric VQGAN can be widely used in StableDiffusion-based inpainting and local editing methods. Extensive experiments demonstrate that it can significantly improve the inpainting and editing performance, while maintaining the original text-to-image capability. The code is available at https://github.com/buxiangzhiren/Asymmetric_VQGAN*

Evaluation results can be found in section 4.1 of the original paper.

## Available checkpoints

* [https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-1-5](https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-1-5)
* [https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-2](https://huggingface.co/cross-attention/asymmetric-autoencoder-kl-x-2)

## Example Usage

```python
from io import BytesIO

from PIL import Image

import requests

from diffusers import AsymmetricAutoencoderKL, StableDiffusionInpaintPipeline


def download_image(url: str) -> Image.Image:
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert("RGB")


prompt = "a photo of a person"
img_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/celeba_hq_256.png"
mask_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/repaint/mask_256.png"

image = download_image(img_url).resize((256, 256))
mask_image = download_image(mask_url).resize((256, 256))

pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
pipe.vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
pipe.to("cuda")

image = pipe(prompt=prompt, image=image, mask_image=mask_image).images[0]
image.save("image.jpeg")
```
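The asymmetric VAE can also be exercised on its own. The following is a minimal sketch (not part of the original file) that encodes an image and then calls `decode` with the mask-conditioning arguments exposed by `AsymmetricAutoencoderKL`; the random tensors, their 512x512 shapes, and the normalization to `[-1, 1]` are assumptions for illustration only.

```python
import torch
from diffusers import AsymmetricAutoencoderKL

vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5").to("cuda").eval()

# assumed inputs: a 1x3x512x512 image in [-1, 1] and a 1x1x512x512 mask (1 = region to inpaint)
image = torch.randn(1, 3, 512, 512, device="cuda")
mask = torch.ones(1, 1, 512, 512, device="cuda")

with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample()
    # the decoder's conditional branch receives the original image and mask
    reconstruction = vae.decode(latents, image=image, mask=mask).sample
print(reconstruction.shape)  # torch.Size([1, 3, 512, 512])
```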
## AsymmetricAutoencoderKL

[[autodoc]] models.autoencoder_asym_kl.AsymmetricAutoencoderKL

## AutoencoderKLOutput

[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput

## DecoderOutput

[[autodoc]] models.vae.DecoderOutput
scripts/convert_asymmetric_vqgan_to_diffusers.py (new file, 184 lines)
@@ -0,0 +1,184 @@
import argparse
import time
from pathlib import Path
from typing import Any, Dict, Literal

import torch

from diffusers import AsymmetricAutoencoderKL


ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG = {
    "in_channels": 3,
    "out_channels": 3,
    "down_block_types": [
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
    ],
    "down_block_out_channels": [128, 256, 512, 512],
    "layers_per_down_block": 2,
    "up_block_types": [
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
    ],
    "up_block_out_channels": [192, 384, 768, 768],
    "layers_per_up_block": 3,
    "act_fn": "silu",
    "latent_channels": 4,
    "norm_num_groups": 32,
    "sample_size": 256,
    "scaling_factor": 0.18215,
}

ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG = {
    "in_channels": 3,
    "out_channels": 3,
    "down_block_types": [
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
    ],
    "down_block_out_channels": [128, 256, 512, 512],
    "layers_per_down_block": 2,
    "up_block_types": [
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
    ],
    "up_block_out_channels": [256, 512, 1024, 1024],
    "layers_per_up_block": 5,
    "act_fn": "silu",
    "latent_channels": 4,
    "norm_num_groups": 32,
    "sample_size": 256,
    "scaling_factor": 0.18215,
}


def convert_asymmetric_autoencoder_kl_state_dict(original_state_dict: Dict[str, Any]) -> Dict[str, Any]:
    converted_state_dict = {}
    for k, v in original_state_dict.items():
        if k.startswith("encoder."):
            converted_state_dict[
                k.replace("encoder.down.", "encoder.down_blocks.")
                .replace("encoder.mid.", "encoder.mid_block.")
                .replace("encoder.norm_out.", "encoder.conv_norm_out.")
                .replace(".downsample.", ".downsamplers.0.")
                .replace(".nin_shortcut.", ".conv_shortcut.")
                .replace(".block.", ".resnets.")
                .replace(".block_1.", ".resnets.0.")
                .replace(".block_2.", ".resnets.1.")
                .replace(".attn_1.k.", ".attentions.0.to_k.")
                .replace(".attn_1.q.", ".attentions.0.to_q.")
                .replace(".attn_1.v.", ".attentions.0.to_v.")
                .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
                .replace(".attn_1.norm.", ".attentions.0.group_norm.")
            ] = v
        elif k.startswith("decoder.") and "up_layers" not in k:
            converted_state_dict[
                k.replace("decoder.encoder.", "decoder.condition_encoder.")
                .replace(".norm_out.", ".conv_norm_out.")
                .replace(".up.0.", ".up_blocks.3.")
                .replace(".up.1.", ".up_blocks.2.")
                .replace(".up.2.", ".up_blocks.1.")
                .replace(".up.3.", ".up_blocks.0.")
                .replace(".block.", ".resnets.")
                .replace("mid", "mid_block")
                .replace(".0.upsample.", ".0.upsamplers.0.")
                .replace(".1.upsample.", ".1.upsamplers.0.")
                .replace(".2.upsample.", ".2.upsamplers.0.")
                .replace(".nin_shortcut.", ".conv_shortcut.")
                .replace(".block_1.", ".resnets.0.")
                .replace(".block_2.", ".resnets.1.")
                .replace(".attn_1.k.", ".attentions.0.to_k.")
                .replace(".attn_1.q.", ".attentions.0.to_q.")
                .replace(".attn_1.v.", ".attentions.0.to_v.")
                .replace(".attn_1.proj_out.", ".attentions.0.to_out.0.")
                .replace(".attn_1.norm.", ".attentions.0.group_norm.")
            ] = v
        elif k.startswith("quant_conv."):
            converted_state_dict[k] = v
        elif k.startswith("post_quant_conv."):
            converted_state_dict[k] = v
        else:
            print(f"  skipping key `{k}`")
    # fix weights shape
    for k, v in converted_state_dict.items():
        if (
            (k.startswith("encoder.mid_block.attentions.0") or k.startswith("decoder.mid_block.attentions.0"))
            and k.endswith("weight")
            and ("to_q" in k or "to_k" in k or "to_v" in k or "to_out" in k)
        ):
            converted_state_dict[k] = converted_state_dict[k][:, :, 0, 0]

    return converted_state_dict


def get_asymmetric_autoencoder_kl_from_original_checkpoint(
    scale: Literal["1.5", "2"], original_checkpoint_path: str, map_location: torch.device
) -> AsymmetricAutoencoderKL:
    print("Loading original state_dict")
    original_state_dict = torch.load(original_checkpoint_path, map_location=map_location)
    original_state_dict = original_state_dict["state_dict"]
    print("Converting state_dict")
    converted_state_dict = convert_asymmetric_autoencoder_kl_state_dict(original_state_dict)
    kwargs = ASYMMETRIC_AUTOENCODER_KL_x_1_5_CONFIG if scale == "1.5" else ASYMMETRIC_AUTOENCODER_KL_x_2_CONFIG
    print("Initializing AsymmetricAutoencoderKL model")
    asymmetric_autoencoder_kl = AsymmetricAutoencoderKL(**kwargs)
    print("Loading weight from converted state_dict")
    asymmetric_autoencoder_kl.load_state_dict(converted_state_dict)
    asymmetric_autoencoder_kl.eval()
    print("AsymmetricAutoencoderKL successfully initialized")
    return asymmetric_autoencoder_kl


if __name__ == "__main__":
    start = time.time()
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--scale",
        default=None,
        type=str,
        required=True,
        help="Asymmetric VQGAN scale: `1.5` or `2`",
    )
    parser.add_argument(
        "--original_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the original Asymmetric VQGAN checkpoint",
    )
    parser.add_argument(
        "--output_path",
        default=None,
        type=str,
        required=True,
        help="Path to save pretrained AsymmetricAutoencoderKL model",
    )
    parser.add_argument(
        "--map_location",
        default="cpu",
        type=str,
        required=False,
        help="The device passed to `map_location` when loading the checkpoint",
    )
    args = parser.parse_args()
    assert args.scale in ["1.5", "2"], f"{args.scale} should be `1.5` or `2`"
    assert Path(args.original_checkpoint_path).is_file()

    asymmetric_autoencoder_kl = get_asymmetric_autoencoder_kl_from_original_checkpoint(
        scale=args.scale,
        original_checkpoint_path=args.original_checkpoint_path,
        map_location=torch.device(args.map_location),
    )
    print("Saving pretrained AsymmetricAutoencoderKL")
    asymmetric_autoencoder_kl.save_pretrained(args.output_path)
    print(f"Done in {time.time() - start:.2f} seconds")
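The converter above can also be driven from Python instead of the CLI. A minimal sketch, assuming it is run from the `scripts/` directory so the module is importable, and with a placeholder checkpoint filename (not a path shipped with the original repository):

```python
# Hypothetical invocation of the conversion helpers defined in the script above.
import torch

from convert_asymmetric_vqgan_to_diffusers import get_asymmetric_autoencoder_kl_from_original_checkpoint

vae = get_asymmetric_autoencoder_kl_from_original_checkpoint(
    scale="1.5",
    original_checkpoint_path="./asymmetric_vqgan_x1_5.ckpt",  # placeholder path for the downloaded checkpoint
    map_location=torch.device("cpu"),
)
vae.save_pretrained("./asymmetric-autoencoder-kl-x-1-5")
```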
src/diffusers/__init__.py
@@ -36,6 +36,7 @@ except OptionalDependencyNotAvailable:
     from .utils.dummy_pt_objects import *  # noqa F403
 else:
     from .models import (
+        AsymmetricAutoencoderKL,
         AutoencoderKL,
         ControlNetModel,
         ModelMixin,
src/diffusers/models/__init__.py
@@ -17,6 +17,7 @@ from ..utils import is_flax_available, is_torch_available
 
 if is_torch_available():
     from .adapter import MultiAdapter, T2IAdapter
+    from .autoencoder_asym_kl import AsymmetricAutoencoderKL
     from .autoencoder_kl import AutoencoderKL
     from .controlnet import ControlNetModel
     from .dual_transformer_2d import DualTransformer2DModel
src/diffusers/models/autoencoder_asym_kl.py (new file, 180 lines)
@@ -0,0 +1,180 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import apply_forward_hook
from .autoencoder_kl import AutoencoderKLOutput
from .modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder

class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
    r"""
    Designing a Better Asymmetric VQGAN for StableDiffusion https://arxiv.org/abs/2306.04632 . A VAE model with KL loss
    for encoding images into latents and decoding latent representations into images.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        down_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of down block output channels.
        layers_per_down_block (`int`, *optional*, defaults to `1`):
            Number of layers for each down block.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        up_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of up block output channels.
        layers_per_up_block (`int`, *optional*, defaults to `1`):
            Number of layers for each up block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        norm_num_groups (`int`, *optional*, defaults to `32`):
            Number of groups to use for the first normalization layer in ResNet blocks.
        scaling_factor (`float`, *optional*, defaults to 0.18215):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        down_block_out_channels: Tuple[int] = (64,),
        layers_per_down_block: int = 1,
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        up_block_out_channels: Tuple[int] = (64,),
        layers_per_up_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
    ) -> None:
        super().__init__()

        # pass init params to Encoder
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=down_block_out_channels,
            layers_per_block=layers_per_down_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )

        # pass init params to Decoder
        self.decoder = MaskConditionDecoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=up_block_out_channels,
            layers_per_block=layers_per_up_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
        )

        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)

        self.use_slicing = False
        self.use_tiling = False

    @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(
        self,
        z: torch.FloatTensor,
        image: Optional[torch.FloatTensor] = None,
        mask: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        z = self.post_quant_conv(z)
        dec = self.decoder(z, image, mask)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(
        self,
        z: torch.FloatTensor,
        image: Optional[torch.FloatTensor] = None,
        mask: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        decoded = self._decode(z, image, mask).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def forward(
        self,
        sample: torch.FloatTensor,
        mask: Optional[torch.FloatTensor] = None,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            mask (`torch.FloatTensor`, *optional*, defaults to `None`): Optional inpainting mask.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z, sample, mask).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
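As a sanity check of the `forward` signature above, a minimal sketch that instantiates a small model with the same configuration the unit tests below use and runs one masked forward pass; the random input tensors are assumptions for illustration.

```python
import torch
from diffusers import AsymmetricAutoencoderKL

# small configuration mirroring prepare_init_args_and_inputs_for_common in the tests added in this commit
vae = AsymmetricAutoencoderKL(
    down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
    down_block_out_channels=[32, 64],
    layers_per_down_block=1,
    up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
    up_block_out_channels=[32, 64],
    layers_per_up_block=1,
    latent_channels=4,
    norm_num_groups=32,
    sample_size=32,
)

sample = torch.randn(4, 3, 32, 32)  # input images
mask = torch.ones(4, 1, 32, 32)     # inpainting mask (all ones here)

with torch.no_grad():
    out = vae(sample, mask=mask, sample_posterior=True).sample
assert out.shape == sample.shape    # the decoder reconstructs at the input resolution
```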
src/diffusers/models/vae.py
@@ -280,6 +280,253 @@ class Decoder(nn.Module):
        return sample


class UpSample(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        x = torch.relu(x)
        x = self.deconv(x)
        return x


class MaskConditionEncoder(nn.Module):
    """
    used in AsymmetricAutoencoderKL
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int = 192,
        res_ch: int = 768,
        stride: int = 16,
    ) -> None:
        super().__init__()

        channels = []
        while stride > 1:
            stride = stride // 2
            in_ch_ = out_ch * 2
            if out_ch > res_ch:
                out_ch = res_ch
            if stride == 1:
                in_ch_ = res_ch
            channels.append((in_ch_, out_ch))
            out_ch *= 2

        out_channels = []
        for _in_ch, _out_ch in channels:
            out_channels.append(_out_ch)
        out_channels.append(channels[-1][0])

        layers = []
        in_ch_ = in_ch
        for l in range(len(out_channels)):
            out_ch_ = out_channels[l]
            if l == 0 or l == 1:
                layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1))
            else:
                layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1))
            in_ch_ = out_ch_

        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
        out = {}
        for l in range(len(self.layers)):
            layer = self.layers[l]
            x = layer(x)
            out[str(tuple(x.shape))] = x
            x = torch.relu(x)
        return out


class MaskConditionDecoder(nn.Module):
    """The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
    decoder with a conditioner on the mask and masked image."""

    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        up_block_types=("UpDecoderBlock2D",),
        block_out_channels=(64,),
        layers_per_block=2,
        norm_num_groups=32,
        act_fn="silu",
        norm_type="group",  # group, spatial
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1

            up_block = get_up_block(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=not is_final_block,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # condition encoder
        self.condition_encoder = MaskConditionEncoder(
            in_ch=out_channels,
            out_ch=block_out_channels[0],
            res_ch=block_out_channels[-1],
        )

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, z, image=None, mask=None, latent_embeds=None):
        sample = z
        sample = self.conv_in(sample)

        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False
                )
                sample = sample.to(upscale_dtype)

                # condition encoder
                if image is not None and mask is not None:
                    masked_image = (1 - mask) * image
                    im_x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(self.condition_encoder), masked_image, mask, use_reentrant=False
                    )

                # up
                for up_block in self.up_blocks:
                    if image is not None and mask is not None:
                        sample_ = im_x[str(tuple(sample.shape))]
                        mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
                        sample = sample * mask_ + sample_ * (1 - mask_)
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False
                    )
                if image is not None and mask is not None:
                    sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, latent_embeds
                )
                sample = sample.to(upscale_dtype)

                # condition encoder
                if image is not None and mask is not None:
                    masked_image = (1 - mask) * image
                    im_x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(self.condition_encoder), masked_image, mask
                    )

                # up
                for up_block in self.up_blocks:
                    if image is not None and mask is not None:
                        sample_ = im_x[str(tuple(sample.shape))]
                        mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
                        sample = sample * mask_ + sample_ * (1 - mask_)
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
                if image is not None and mask is not None:
                    sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            sample = sample.to(upscale_dtype)

            # condition encoder
            if image is not None and mask is not None:
                masked_image = (1 - mask) * image
                im_x = self.condition_encoder(masked_image, mask)

            # up
            for up_block in self.up_blocks:
                if image is not None and mask is not None:
                    sample_ = im_x[str(tuple(sample.shape))]
                    mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
                    sample = sample * mask_ + sample_ * (1 - mask_)
                sample = up_block(sample, latent_embeds)
            if image is not None and mask is not None:
                sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class VectorQuantizer(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
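The contract between the two new classes is easy to miss in the hunk above: `MaskConditionEncoder.forward` returns a dict keyed by the string of each intermediate feature-map shape, and `MaskConditionDecoder` looks features up by the shape of its current `sample` before every up block, blending them with a nearest-resized mask. A minimal sketch of that lookup-and-blend step, with made-up shapes:

```python
# Sketch of the per-resolution blend used before each up block; shapes are illustrative only.
import torch

im_x = {"(1, 64, 32, 32)": torch.randn(1, 64, 32, 32)}  # masked-image features at one resolution
sample = torch.randn(1, 64, 32, 32)                      # decoder features at the same resolution
mask = torch.ones(1, 1, 256, 256)                        # full-resolution inpainting mask (1 = hole)

mask_ = torch.nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
sample_ = im_x[str(tuple(sample.shape))]
# keep the decoder's prediction inside the hole (mask == 1), reuse the unmasked-image priors outside it
sample = sample * mask_ + sample_ * (1 - mask_)
```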
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import inspect
-import warnings
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
@@ -25,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
 from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
@@ -180,7 +179,7 @@ class StableDiffusionInpaintPipeline(
     </Tip>
 
     Args:
-        vae ([`AutoencoderKL`]):
+        vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
         text_encoder ([`CLIPTextModel`]):
             Frozen text-encoder. Stable Diffusion uses the text portion of
@@ -203,7 +202,7 @@ class StableDiffusionInpaintPipeline(
 
     def __init__(
         self,
-        vae: AutoencoderKL,
+        vae: Union[AutoencoderKL, AsymmetricAutoencoderKL],
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
@@ -513,20 +512,6 @@ class StableDiffusionInpaintPipeline(
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
-    def decode_latents(self, latents):
-        warnings.warn(
-            "The decode_latents method is deprecated and will be removed in a future version. Please"
-            " use VaeImageProcessor instead",
-            FutureWarning,
-        )
-        latents = 1 / self.vae.config.scaling_factor * latents
-        image = self.vae.decode(latents, return_dict=False)[0]
-        image = (image / 2 + 0.5).clamp(0, 1)
-        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
-        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-        return image
-
     def check_inputs(
         self,
         prompt,
@@ -897,6 +882,7 @@ class StableDiffusionInpaintPipeline(
         mask, masked_image, init_image = prepare_mask_and_masked_image(
             image, mask_image, height, width, return_image=True
         )
+        mask_condition = mask.clone()
 
         # 6. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
@@ -1007,7 +993,14 @@ class StableDiffusionInpaintPipeline(
                     callback(i, t, latents)
 
         if not output_type == "latent":
-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            condition_kwargs = {}
+            if isinstance(self.vae, AsymmetricAutoencoderKL):
+                init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
+                init_image_condition = init_image.clone()
+                init_image = self._encode_vae_image(init_image, generator=generator)
+                mask_condition = mask_condition.to(device=device, dtype=masked_image_latents.dtype)
+                condition_kwargs = {"image": init_image_condition, "mask": mask_condition}
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, **condition_kwargs)[0]
             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
         else:
             image = latents
src/diffusers/utils/dummy_pt_objects.py
@@ -2,6 +2,21 @@
 from ..utils import DummyObject, requires_backends
 
 
+class AsymmetricAutoencoderKL(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class AutoencoderKL(metaclass=DummyObject):
     _backends = ["torch"]
tests/models/test_models_vae.py
@@ -19,7 +19,7 @@ import unittest
 import torch
 from parameterized import parameterized
 
-from diffusers import AutoencoderKL
+from diffusers import AsymmetricAutoencoderKL, AutoencoderKL
 from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slow, torch_all_close, torch_device
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import enable_full_determinism
@@ -173,6 +173,56 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))


class AsymmetricAutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
    model_class = AsymmetricAutoencoderKL
    main_input_name = "sample"
    base_precision = 1e-2

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 3
        sizes = (32, 32)

        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        mask = torch.ones((batch_size, 1) + sizes).to(torch_device)

        return {"sample": image, "mask": mask}

    @property
    def input_shape(self):
        return (3, 32, 32)

    @property
    def output_shape(self):
        return (3, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
            "down_block_out_channels": [32, 64],
            "layers_per_down_block": 1,
            "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"],
            "up_block_out_channels": [32, 64],
            "layers_per_up_block": 1,
            "act_fn": "silu",
            "latent_channels": 4,
            "norm_num_groups": 32,
            "sample_size": 32,
            "scaling_factor": 0.18215,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_forward_signature(self):
        pass

    def test_forward_with_norm_groups(self):
        pass


@slow
class AutoencoderKLIntegrationTests(unittest.TestCase):
    def get_file_format(self, seed, shape):
@@ -402,3 +452,149 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
        output_slice_2 = sample_2[-1, -2:, -2:, :2].flatten().float().cpu()

        assert torch_all_close(output_slice_1, output_slice_2, atol=3e-3)


@slow
class AsymmetricAutoencoderKLIntegrationTests(unittest.TestCase):
    def get_file_format(self, seed, shape):
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32
        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return image

    def get_sd_vae_model(self, model_id="cross-attention/asymmetric-autoencoder-kl-x-1-5", fp16=False):
        revision = "main"
        torch_dtype = torch.float32

        model = AsymmetricAutoencoderKL.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            revision=revision,
        )
        model.to(torch_device).eval()

        return model

    def get_generator(self, seed=0):
        if torch_device == "mps":
            return torch.manual_seed(seed)
        return torch.Generator(device=torch_device).manual_seed(seed)

    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078], [-0.1603, 0.9878, -0.0495, -0.0790, -0.2709, 0.8375, -0.2060, -0.0824]],
            [47, [0.4400, 0.0543, 0.2873, 0.2946, 0.0553, 0.0839, -0.1585, 0.2529], [-0.2376, 0.1168, 0.1332, -0.4840, -0.2508, -0.0791, -0.0493, -0.4089]],
            # fmt: on
        ]
    )
    def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)
        generator = self.get_generator(seed)

        with torch.no_grad():
            sample = model(image, generator=generator, sample_posterior=True).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.0340, 0.2870, 0.1698, -0.0105, -0.3448, 0.3529, -0.1321, 0.1097], [-0.0344, 0.2912, 0.1687, -0.0137, -0.3462, 0.3552, -0.1337, 0.1078]],
            [47, [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531], [0.4397, 0.0550, 0.2873, 0.2946, 0.0567, 0.0855, -0.1580, 0.2531]],
            # fmt: on
        ]
    )
    def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)

        with torch.no_grad():
            sample = model(image).sample

        assert sample.shape == image.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)

    @parameterized.expand(
        [
            # fmt: off
            [13, [-0.0521, -0.2939, 0.1540, -0.1855, -0.5936, -0.3138, -0.4579, -0.2275]],
            [37, [-0.1820, -0.4345, -0.0455, -0.2923, -0.8035, -0.5089, -0.4795, -0.3106]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_stable_diffusion_decode(self, seed, expected_slice):
        model = self.get_sd_vae_model()
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))

        with torch.no_grad():
            sample = model.decode(encoding).sample

        assert list(sample.shape) == [3, 3, 512, 512]

        output_slice = sample[-1, -2:, :2, -2:].flatten().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=2e-3)

    @parameterized.expand([(13,), (16,), (37,)])
    @require_torch_gpu
    @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
    def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
        model = self.get_sd_vae_model()
        encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))

        with torch.no_grad():
            sample = model.decode(encoding).sample

        model.enable_xformers_memory_efficient_attention()
        with torch.no_grad():
            sample_2 = model.decode(encoding).sample

        assert list(sample.shape) == [3, 3, 512, 512]

        assert torch_all_close(sample, sample_2, atol=5e-2)

    @parameterized.expand(
        [
            # fmt: off
            [33, [-0.3001, 0.0918, -2.6984, -3.9720, -3.2099, -5.0353, 1.7338, -0.2065, 3.4267]],
            [47, [-1.5030, -4.3871, -6.0355, -9.1157, -1.6661, -2.7853, 2.1607, -5.0823, 2.5633]],
            # fmt: on
        ]
    )
    def test_stable_diffusion_encode_sample(self, seed, expected_slice):
        model = self.get_sd_vae_model()
        image = self.get_sd_image(seed)
        generator = self.get_generator(seed)

        with torch.no_grad():
            dist = model.encode(image).latent_dist
            sample = dist.sample(generator=generator)

        assert list(sample.shape) == [image.shape[0], 4] + [i // 8 for i in image.shape[2:]]

        output_slice = sample[0, -1, -3:, -3:].flatten().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        tolerance = 3e-3 if torch_device != "mps" else 1e-2
        assert torch_all_close(output_slice, expected_output_slice, atol=tolerance)
tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -25,6 +25,7 @@ from PIL import Image
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from diffusers import (
+    AsymmetricAutoencoderKL,
     AutoencoderKL,
     DDIMScheduler,
     DPMSolverMultistepScheduler,
@@ -552,6 +553,230 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
        assert np.max(np.abs(image - image_ckpt)) < 1e-4


@slow
@require_torch_gpu
class StableDiffusionInpaintPipelineAsymmetricAutoencoderKLSlowTests(unittest.TestCase):
    def setUp(self):
        super().setUp()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
        generator = torch.Generator(device=generator_device).manual_seed(seed)
        init_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_image.png"
        )
        mask_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_inpaint/input_bench_mask.png"
        )
        inputs = {
            "prompt": "Face of a yellow cat, high resolution, sitting on a park bench",
            "image": init_image,
            "mask_image": mask_image,
            "generator": generator,
            "num_inference_steps": 3,
            "guidance_scale": 7.5,
            "output_type": "numpy",
        }
        return inputs

    def test_stable_diffusion_inpaint_ddim(self):
        vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None
        )
        pipe.vae = vae
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.0521, 0.0606, 0.0602, 0.0446, 0.0495, 0.0434, 0.1175, 0.1290, 0.1431])

        assert np.abs(expected_slice - image_slice).max() < 6e-4

    def test_stable_diffusion_inpaint_fp16(self):
        vae = AsymmetricAutoencoderKL.from_pretrained(
            "cross-attention/asymmetric-autoencoder-kl-x-1-5", torch_dtype=torch.float16
        )
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16, safety_checker=None
        )
        pipe.vae = vae
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device, dtype=torch.float16)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.1343, 0.1406, 0.1440, 0.1504, 0.1729, 0.0989, 0.1807, 0.2822, 0.1179])

        assert np.abs(expected_slice - image_slice).max() < 5e-2

    def test_stable_diffusion_inpaint_pndm(self):
        vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None
        )
        pipe.vae = vae
        pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.0976, 0.1071, 0.1119, 0.1363, 0.1260, 0.1150, 0.3745, 0.3586, 0.3340])

        assert np.abs(expected_slice - image_slice).max() < 5e-3

    def test_stable_diffusion_inpaint_k_lms(self):
        vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None
        )
        pipe.vae = vae
        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images
        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.8909, 0.8620, 0.9024, 0.8501, 0.8558, 0.9074, 0.8790, 0.7540, 0.9003])

        assert np.abs(expected_slice - image_slice).max() < 6e-3

    def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        vae = AsymmetricAutoencoderKL.from_pretrained(
            "cross-attention/asymmetric-autoencoder-kl-x-1-5", torch_dtype=torch.float16
        )
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None, torch_dtype=torch.float16
        )
        pipe.vae = vae
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing(1)
        pipe.enable_sequential_cpu_offload()

        inputs = self.get_inputs(torch_device, dtype=torch.float16)
        _ = pipe(**inputs)

        mem_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 2.45 GB is allocated
        assert mem_bytes < 2.45 * 10**9

    @require_torch_2
    def test_inpaint_compile(self):
        pass

    def test_stable_diffusion_inpaint_pil_input_resolution_test(self):
        vae = AsymmetricAutoencoderKL.from_pretrained(
            "cross-attention/asymmetric-autoencoder-kl-x-1-5",
        )
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None
        )
        pipe.vae = vae
        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        # change input image to a random size (one that would cause a tensor mismatch error)
        inputs["image"] = inputs["image"].resize((127, 127))
        inputs["mask_image"] = inputs["mask_image"].resize((127, 127))
        inputs["height"] = 128
        inputs["width"] = 128
        image = pipe(**inputs).images
        # verify that the returned image has the same height and width as the input height and width
        assert image.shape == (1, inputs["height"], inputs["width"], 3)

    def test_stable_diffusion_inpaint_strength_test(self):
        vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting", safety_checker=None
        )
        pipe.vae = vae
        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        # change input strength
        inputs["strength"] = 0.75
        image = pipe(**inputs).images
        # verify that the returned image has the same height and width as the input height and width
        assert image.shape == (1, 512, 512, 3)

        image_slice = image[0, 253:256, 253:256, -1].flatten()
        expected_slice = np.array([0.2458, 0.2576, 0.3124, 0.2679, 0.2669, 0.2796, 0.2872, 0.2975, 0.2661])
        assert np.abs(expected_slice - image_slice).max() < 3e-3

    def test_stable_diffusion_simple_inpaint_ddim(self):
        vae = AsymmetricAutoencoderKL.from_pretrained("cross-attention/asymmetric-autoencoder-kl-x-1-5")
        pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None)
        pipe.vae = vae
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images

        image_slice = image[0, 253:256, 253:256, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.3312, 0.4052, 0.4103, 0.4153, 0.4347, 0.4154, 0.4932, 0.4920, 0.4431])

        assert np.abs(expected_slice - image_slice).max() < 6e-4

    def test_download_local(self):
        vae = AsymmetricAutoencoderKL.from_pretrained(
            "cross-attention/asymmetric-autoencoder-kl-x-1-5", torch_dtype=torch.float16
        )
        filename = hf_hub_download("runwayml/stable-diffusion-inpainting", filename="sd-v1-5-inpainting.ckpt")

        pipe = StableDiffusionInpaintPipeline.from_single_file(filename, torch_dtype=torch.float16)
        pipe.vae = vae
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.to("cuda")

        inputs = self.get_inputs(torch_device)
        inputs["num_inference_steps"] = 1
        image_out = pipe(**inputs).images[0]

        assert image_out.shape == (512, 512, 3)

    def test_download_ckpt_diff_format_is_same(self):
        pass


@nightly
@require_torch_gpu
class StableDiffusionInpaintPipelineNightlyTests(unittest.TestCase):