Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-08 21:44:27 +08:00)

Compare commits: `note-model` ... `export-to-`, 1 commit (97c611402f)
@@ -24,4 +24,4 @@ The abstract from the paper is:

## PriorTransformerOutput

[[autodoc]] models.transformers.prior_transformer.PriorTransformerOutput
[[autodoc]] models.prior_transformer.PriorTransformerOutput
@@ -38,4 +38,4 @@ It is assumed one of the input classes is the masked latent pixel. The predicted

## Transformer2DModelOutput

[[autodoc]] models.transformers.transformer_2d.Transformer2DModelOutput
[[autodoc]] models.transformer_2d.Transformer2DModelOutput
@@ -16,8 +16,8 @@ A Transformer model for video-like data.

## TransformerTemporalModel

[[autodoc]] models.transformers.transformer_temporal.TransformerTemporalModel
[[autodoc]] models.transformer_temporal.TransformerTemporalModel

## TransformerTemporalModelOutput

[[autodoc]] models.transformers.transformer_temporal.TransformerTemporalModelOutput
[[autodoc]] models.transformer_temporal.TransformerTemporalModelOutput
@@ -104,7 +104,7 @@ accelerate launch train_text_to_image_lora.py \

Many of the basic and important parameters are described in the [Text-to-image](text2image#script-parameters) training guide, so this guide just focuses on the LoRA relevant parameters:

- `--rank`: the inner dimension of the low-rank matrices to train; a higher rank means more trainable parameters
- `--rank`: the number of low-rank matrices to train
- `--learning_rate`: the default learning rate is 1e-4, but with LoRA, you can use a higher learning rate
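To make the `--rank` option concrete, here is a minimal, hypothetical sketch (not the training script's actual code) of how a rank value typically maps to a PEFT `LoraConfig`; the target module names are assumptions for a typical UNet attention block:

```python
from peft import LoraConfig

# Illustrative only: assumes the script wires --rank into a PEFT LoraConfig.
rank = 4  # value passed via --rank
unet_lora_config = LoraConfig(
    r=rank,           # inner dimension of the low-rank update matrices
    lora_alpha=rank,  # a common choice is to scale alpha with the rank
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # assumed attention projections
)
```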
## Training script
@@ -206,13 +206,3 @@ pipe.fuse_lora(adapter_names=["pixel", "toy"])

prompt = "toy_face of a hacker with a hoodie, pixel art"
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
```

## Saving a pipeline after fusing the adapters

To properly save a pipeline after it's been loaded with the adapters, it should be serialized like so:

```python
pipe.fuse_lora(lora_scale=1.0)
pipe.unload_lora_weights()
pipe.save_pretrained("path-to-pipeline")
```
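As a usage note (not part of the original snippet), the serialized pipeline can be reloaded like any other saved pipeline; `"path-to-pipeline"` is the same placeholder directory used above:

```python
import torch
from diffusers import DiffusionPipeline

# Reload the pipeline; the fused LoRA weights are now baked into the base weights.
pipe = DiffusionPipeline.from_pretrained("path-to-pipeline", torch_dtype=torch.float16).to("cuda")
image = pipe("toy_face of a hacker with a hoodie, pixel art", num_inference_steps=30).images[0]
```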
@@ -62,7 +62,6 @@ If a community doesn't work as expected, please open an issue and ping the autho

| AnimateDiff Image-To-Video Pipeline | Experimental Image-To-Video support for AnimateDiff (open to improvements) | [AnimateDiff Image To Video Pipeline](#animatediff-image-to-video-pipeline) | [](https://drive.google.com/file/d/1TvzCDPHhfFtdcJZe4RLloAwyoLKuttWK/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) |
| InstantID Pipeline | Stable Diffusion XL Pipeline that supports InstantID | [InstantID Pipeline](#instantid-pipeline) | [](https://huggingface.co/spaces/InstantX/InstantID) | [Haofan Wang](https://github.com/haofanwang) |
| UFOGen Scheduler | Scheduler for UFOGen Model (compatible with Stable Diffusion pipelines) | [UFOGen Scheduler](#ufogen-scheduler) | - | [dg845](https://github.com/dg845) |

To load a custom pipeline, pass the `custom_pipeline` argument to `DiffusionPipeline`, set to the name of one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines; we will merge them quickly.
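A minimal sketch of that pattern follows; the pipeline name `lpw_stable_diffusion` is only an illustrative choice of community file, not one of the rows above:

```python
import torch
from diffusers import DiffusionPipeline

# `custom_pipeline` names a file in diffusers/examples/community.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="lpw_stable_diffusion",
    torch_dtype=torch.float16,
).to("cuda")
```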
@@ -3606,32 +3605,3 @@ image = pipe(

    controlnet_conditioning_scale=0.8,
).images[0]
```

### UFOGen Scheduler

[UFOGen](https://arxiv.org/abs/2311.09257) is a generative model designed for fast one-step text-to-image generation, trained via adversarial training starting from an initial pretrained diffusion model such as Stable Diffusion. `scheduling_ufogen.py` implements one-step and multi-step sampling algorithms for UFOGen models compatible with pipelines like `StableDiffusionPipeline`. A usage example is as follows:

```py
import torch
from diffusers import StableDiffusionPipeline

from scheduling_ufogen import UFOGenScheduler

# NOTE: currently, I am not aware of any publicly available UFOGen model checkpoints trained from SD v1.5.
ufogen_model_id_or_path = "/path/to/ufogen/model"
pipe = StableDiffusionPipeline.from_pretrained(
    ufogen_model_id_or_path,
    torch_dtype=torch.float16,
)

# You can initialize a UFOGenScheduler as follows:
pipe.scheduler = UFOGenScheduler.from_config(pipe.scheduler.config)

prompt = "Three cats having dinner at a table at new years eve, cinematic shot, 8k."

# Onestep sampling
onestep_image = pipe(prompt, num_inference_steps=1).images[0]

# Multistep sampling
multistep_image = pipe(prompt, num_inference_steps=4).images[0]
```
@@ -1,525 +0,0 @@

# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput
from diffusers.utils.torch_utils import randn_tensor


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->UFOGen
class UFOGenSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
            prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
            Choose from `cosine` or `exp`

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)


# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas


class UFOGenScheduler(SchedulerMixin, ConfigMixin):
    """
    `UFOGenScheduler` implements multistep and onestep sampling for a UFOGen model, introduced in
    [UFOGen: You Forward Once Large Scale Text-to-Image Generation via Diffusion GANs](https://arxiv.org/abs/2311.09257)
    by Yanwu Xu, Yang Zhao, Zhisheng Xiao, and Tingbo Hou. UFOGen is a variant of the denoising diffusion GAN (DDGAN)
    model designed for one-step sampling.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.0001):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.02):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        clip_sample (`bool`, defaults to `True`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
            Diffusion.
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
        denoising_step_size (`int`, defaults to 250):
            The denoising step size parameter from the UFOGen paper. The number of steps used for training is roughly
            `math.ceil(num_train_timesteps / denoising_step_size)`.
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        clip_sample: bool = True,
        set_alpha_to_one: bool = True,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        clip_sample_range: float = 1.0,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        steps_offset: int = 0,
        rescale_betas_zero_snr: bool = False,
        denoising_step_size: int = 250,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        elif beta_schedule == "sigmoid":
            # GeoDiff sigmoid schedule
            betas = torch.linspace(-6, 6, num_train_timesteps)
            self.betas = torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.custom_timesteps = False
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        return sample

    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        timesteps: Optional[List[int]] = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model. If used,
                `timesteps` must be `None`.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                timestep spacing strategy of equal spacing between timesteps is used. If `timesteps` is passed,
                `num_inference_steps` must be `None`.
        """
        if num_inference_steps is not None and timesteps is not None:
            raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")

        if timesteps is not None:
            for i in range(1, len(timesteps)):
                if timesteps[i] >= timesteps[i - 1]:
                    raise ValueError("`custom_timesteps` must be in descending order.")

            if timesteps[0] >= self.config.num_train_timesteps:
                raise ValueError(
                    f"`timesteps` must start before `self.config.train_timesteps`:"
                    f" {self.config.num_train_timesteps}."
                )

            timesteps = np.array(timesteps, dtype=np.int64)
            self.custom_timesteps = True
        else:
            if num_inference_steps > self.config.num_train_timesteps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                    f" maximal {self.config.num_train_timesteps} timesteps."
                )

            self.num_inference_steps = num_inference_steps
            self.custom_timesteps = False

            # TODO: For now, handle special case when num_inference_steps == 1 separately
            if num_inference_steps == 1:
                # Set the timestep schedule to num_train_timesteps - 1 rather than 0
                # (that is, the one-step timestep schedule is always trailing rather than leading or linspace)
                timesteps = np.array([self.config.num_train_timesteps - 1], dtype=np.int64)
            else:
                # TODO: For now, retain the DDPM timestep spacing logic
                # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
                if self.config.timestep_spacing == "linspace":
                    timesteps = (
                        np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
                        .round()[::-1]
                        .copy()
                        .astype(np.int64)
                    )
                elif self.config.timestep_spacing == "leading":
                    step_ratio = self.config.num_train_timesteps // self.num_inference_steps
                    # creates integer timesteps by multiplying by ratio
                    # casting to int to avoid issues when num_inference_step is power of 3
                    timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
                    timesteps += self.config.steps_offset
                elif self.config.timestep_spacing == "trailing":
                    step_ratio = self.config.num_train_timesteps / self.num_inference_steps
                    # creates integer timesteps by multiplying by ratio
                    # casting to int to avoid issues when num_inference_step is power of 3
                    timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
                    timesteps -= 1
                else:
                    raise ValueError(
                        f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
                    )

        self.timesteps = torch.from_numpy(timesteps).to(device)

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[UFOGenSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_ufogen.UFOGenSchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_ufogen.UFOGenSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_ufogen.UFOGenSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.
        """
        # 0. Resolve timesteps
        t = timestep
        prev_t = self.previous_timestep(t)

        # 1. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        # beta_prod_t_prev = 1 - alpha_prod_t_prev
        # current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        # current_beta_t = 1 - current_alpha_t

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
                " `v_prediction` for UFOGenScheduler."
            )

        # 3. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 4. Single-step or multi-step sampling
        # Noise is not used on the final timestep of the timestep schedule.
        # This also means that noise is not used for one-step sampling.
        if t != self.timesteps[-1]:
            # TODO: is this correct?
            # Sample prev sample x_{t - 1} ~ q(x_{t - 1} | x_0 = G(x_t, t))
            device = model_output.device
            noise = randn_tensor(model_output.shape, generator=generator, device=device, dtype=model_output.dtype)
            sqrt_alpha_prod_t_prev = alpha_prod_t_prev**0.5
            sqrt_one_minus_alpha_prod_t_prev = (1 - alpha_prod_t_prev) ** 0.5
            pred_prev_sample = sqrt_alpha_prod_t_prev * pred_original_sample + sqrt_one_minus_alpha_prod_t_prev * noise
        else:
            # Simply return the pred_original_sample. If `prediction_type == "sample"`, this is equivalent to returning
            # the output of the GAN generator U-Net on the initial noisy latents x_T ~ N(0, I).
            pred_prev_sample = pred_original_sample

        if not return_dict:
            return (pred_prev_sample,)

        return UFOGenSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        timesteps = timesteps.to(original_samples.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
    def get_velocity(
        self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
    ) -> torch.FloatTensor:
        # Make sure alphas_cumprod and timestep have same device and dtype as sample
        alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
        timesteps = timesteps.to(sample.device)

        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        while len(sqrt_alpha_prod.shape) < len(sample.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

        velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
        return velocity

    def __len__(self):
        return self.config.num_train_timesteps

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
    def previous_timestep(self, timestep):
        if self.custom_timesteps:
            index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
            if index == self.timesteps.shape[0] - 1:
                prev_t = torch.tensor(-1)
            else:
                prev_t = self.timesteps[index + 1]
        else:
            num_inference_steps = (
                self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
            )
            prev_t = timestep - self.config.num_train_timesteps // num_inference_steps

        return prev_t
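To make the sampling loop implemented by `UFOGenScheduler.step` concrete, here is a minimal driver sketch; `fake_unet` is a stand-in for a real denoising model, so this only illustrates the scheduler API, not actual image generation:

```python
import torch


def fake_unet(sample, t):
    # Placeholder epsilon-prediction model; a real pipeline would call its UNet here.
    return torch.zeros_like(sample)


scheduler = UFOGenScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=4)  # use num_inference_steps=1 for one-step sampling

sample = torch.randn(1, 4, 64, 64)  # x_T ~ N(0, I), latent-shaped
for t in scheduler.timesteps:
    model_output = fake_unet(sample, t)
    sample = scheduler.step(model_output, t, sample).prev_sample
```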
@@ -907,12 +907,10 @@ def main():

if args.snr_gamma is not None:
    snr = jnp.array(compute_snr(timesteps))
    snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma)
    if noise_scheduler.config.prediction_type == "epsilon":
        snr_loss_weights = snr_loss_weights / snr
    elif noise_scheduler.config.prediction_type == "v_prediction":
        snr_loss_weights = snr_loss_weights / (snr + 1)
    if noise_scheduler.config.prediction_type == "v_prediction":
        # Velocity objective requires that we add one to SNR values before we divide by them.
        snr = snr + 1
    snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
    loss = loss * snr_loss_weights

loss = loss.mean()
@@ -753,7 +753,7 @@ def main(args):

num_new_images = args.num_class_images - cur_class_images
logger.info(f"Number of class images to sample: {num_new_images}.")

sample_dataset = PromptDataset(concept["class_prompt"], num_new_images)
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

sample_dataloader = accelerator.prepare(sample_dataloader)
@@ -781,13 +781,12 @@ def main():

# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
    dim=1
)[0]
if noise_scheduler.config.prediction_type == "epsilon":
    mse_loss_weights = mse_loss_weights / snr
elif noise_scheduler.config.prediction_type == "v_prediction":
    mse_loss_weights = mse_loss_weights / (snr + 1)
if noise_scheduler.config.prediction_type == "v_prediction":
    # Velocity objective requires that we add one to SNR values before we divide by them.
    snr = snr + 1
mse_loss_weights = (
    torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)

loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
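For context, the weighting these training-script hunks compute is the Min-SNR-gamma strategy; a standalone sketch of the rule (an illustration, not the scripts' code) is:

```python
import torch


def min_snr_weights(snr: torch.Tensor, gamma: float, prediction_type: str = "epsilon") -> torch.Tensor:
    # min(SNR(t), gamma) / SNR(t) for epsilon-prediction; divide by (SNR(t) + 1) for v-prediction.
    weights = torch.minimum(snr, torch.full_like(snr, gamma))
    if prediction_type == "epsilon":
        return weights / snr
    if prediction_type == "v_prediction":
        return weights / (snr + 1)
    raise ValueError(f"Unsupported prediction_type: {prediction_type}")
```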
@@ -631,13 +631,12 @@ def main():

# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
    dim=1
)[0]
if noise_scheduler.config.prediction_type == "epsilon":
    mse_loss_weights = mse_loss_weights / snr
elif noise_scheduler.config.prediction_type == "v_prediction":
    mse_loss_weights = mse_loss_weights / (snr + 1)
if noise_scheduler.config.prediction_type == "v_prediction":
    # Velocity objective requires that we add one to SNR values before we divide by them.
    snr = snr + 1
mse_loss_weights = (
    torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)

loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
@@ -664,13 +664,12 @@ def main():

# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
    dim=1
)[0]
if noise_scheduler.config.prediction_type == "epsilon":
    mse_loss_weights = mse_loss_weights / snr
elif noise_scheduler.config.prediction_type == "v_prediction":
    mse_loss_weights = mse_loss_weights / (snr + 1)
if noise_scheduler.config.prediction_type == "v_prediction":
    # Velocity objective requires that we add one to SNR values before we divide by them.
    snr = snr + 1
mse_loss_weights = (
    torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)

loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
@@ -811,13 +811,12 @@ def main():

# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
    dim=1
)[0]
if noise_scheduler.config.prediction_type == "epsilon":
    mse_loss_weights = mse_loss_weights / snr
elif noise_scheduler.config.prediction_type == "v_prediction":
    mse_loss_weights = mse_loss_weights / (snr + 1)
if noise_scheduler.config.prediction_type == "v_prediction":
    # Velocity objective requires that we add one to SNR values before we divide by them.
    snr = snr + 1
mse_loss_weights = (
    torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)

loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
@@ -1,24 +0,0 @@

# Consistency Training

`train_cm_ct_unconditional.py` trains a consistency model (CM) from scratch following the consistency training (CT) algorithm introduced in [Consistency Models](https://arxiv.org/abs/2303.01469) and refined in [Improved Techniques for Training Consistency Models](https://arxiv.org/abs/2310.14189). Both unconditional and class-conditional training are supported.

A usage example is as follows:

```bash
accelerate launch examples/research_projects/consistency_training/train_cm_ct_unconditional.py \
    --dataset_name="cifar10" \
    --dataset_image_column_name="img" \
    --output_dir="/path/to/output/dir" \
    --mixed_precision=fp16 \
    --resolution=32 \
    --max_train_steps=1000 --max_train_samples=10000 \
    --dataloader_num_workers=8 \
    --noise_precond_type="cm" --input_precond_type="cm" \
    --train_batch_size=4 \
    --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
    --use_8bit_adam \
    --use_ema \
    --validation_steps=100 --eval_batch_size=4 \
    --checkpointing_steps=100 --checkpoints_total_limit=10 \
    --class_conditional --num_classes=10
```
@@ -1,6 +0,0 @@

accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2

(File diff suppressed because it is too large.)
@@ -741,7 +741,7 @@ def main(args):

combined_im = train_resize(combined_im)

# Flipping.
if not args.no_hflip and random.random() < 0.5:
if not args.no_flip and random.random() < 0.5:
    combined_im = train_flip(combined_im)

# Cropping.
@@ -848,13 +848,12 @@ def main():

# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
    dim=1
)[0]
if noise_scheduler.config.prediction_type == "epsilon":
    mse_loss_weights = mse_loss_weights / snr
elif noise_scheduler.config.prediction_type == "v_prediction":
    mse_loss_weights = mse_loss_weights / (snr + 1)
if noise_scheduler.config.prediction_type == "v_prediction":
    # Velocity objective requires that we add one to SNR values before we divide by them.
    snr = snr + 1
mse_loss_weights = (
    torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)

loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
@@ -943,13 +943,12 @@ def main():

# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
    dim=1
)[0]
if noise_scheduler.config.prediction_type == "epsilon":
    mse_loss_weights = mse_loss_weights / snr
elif noise_scheduler.config.prediction_type == "v_prediction":
    mse_loss_weights = mse_loss_weights / (snr + 1)
if noise_scheduler.config.prediction_type == "v_prediction":
    # Velocity objective requires that we add one to SNR values before we divide by them.
    snr = snr + 1
mse_loss_weights = (
    torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)

loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
@@ -759,13 +759,12 @@ def main():

# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
    dim=1
)[0]
if noise_scheduler.config.prediction_type == "epsilon":
    mse_loss_weights = mse_loss_weights / snr
elif noise_scheduler.config.prediction_type == "v_prediction":
    mse_loss_weights = mse_loss_weights / (snr + 1)
if noise_scheduler.config.prediction_type == "v_prediction":
    # Velocity objective requires that we add one to SNR values before we divide by them.
    snr = snr + 1
mse_loss_weights = (
    torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)

loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
@@ -35,7 +35,7 @@ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration

from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig, set_peft_model_state_dict
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from torchvision.transforms.functional import crop
@@ -51,13 +51,8 @@ from diffusers import (

)
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.training_utils import _set_state_dict_into_text_encoder, cast_training_params, compute_snr
from diffusers.utils import (
    check_min_version,
    convert_state_dict_to_diffusers,
    convert_unet_state_dict_to_peft,
    is_wandb_available,
)
from diffusers.training_utils import cast_training_params, compute_snr
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -634,6 +629,14 @@ def main(args):

    text_encoder_one.add_adapter(text_lora_config)
    text_encoder_two.add_adapter(text_lora_config)

# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
    models = [unet]
    if args.train_text_encoder:
        models.extend([text_encoder_one, text_encoder_two])
    # only upcast trainable parameters (LoRA) into fp32
    cast_training_params(models, dtype=torch.float32)

def unwrap_model(model):
    model = accelerator.unwrap_model(model)
    model = model._orig_mod if is_compiled_module(model) else model
@@ -690,34 +693,18 @@ def main(args):

else:
    raise ValueError(f"unexpected save model: {model.__class__}")

lora_state_dict, _ = LoraLoaderMixin.lora_state_dict(input_dir)
unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")}
unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
if incompatible_keys is not None:
    # check only for unexpected keys
    unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
    if unexpected_keys:
        logger.warning(
            f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
            f" {unexpected_keys}. "
        )
lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)

if args.train_text_encoder:
    _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
    text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
    LoraLoaderMixin.load_lora_into_text_encoder(
        text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
    )

    _set_state_dict_into_text_encoder(
        lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_
    )

# Make sure the trainable params are in float32. This is again needed since the base models
# are in `weight_dtype`. More details:
# https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
if args.mixed_precision == "fp16":
    models = [unet_]
    if args.train_text_encoder:
        models.extend([text_encoder_one_, text_encoder_two_])
    cast_training_params(models, dtype=torch.float32)
    text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
    LoraLoaderMixin.load_lora_into_text_encoder(
        text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
    )

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
@@ -738,13 +725,6 @@ def main(args):

        args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
    )

# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
    models = [unet]
    if args.train_text_encoder:
        models.extend([text_encoder_one, text_encoder_two])
    cast_training_params(models, dtype=torch.float32)

# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
    try:
@@ -1082,13 +1062,12 @@ def main(args):

# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
    dim=1
)[0]
if noise_scheduler.config.prediction_type == "epsilon":
    mse_loss_weights = mse_loss_weights / snr
elif noise_scheduler.config.prediction_type == "v_prediction":
    mse_loss_weights = mse_loss_weights / (snr + 1)
if noise_scheduler.config.prediction_type == "v_prediction":
    # Velocity objective requires that we add one to SNR values before we divide by them.
    snr = snr + 1
mse_loss_weights = (
    torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)

loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
@@ -1087,13 +1087,12 @@ def main(args):

# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
    dim=1
)[0]
if noise_scheduler.config.prediction_type == "epsilon":
    mse_loss_weights = mse_loss_weights / snr
elif noise_scheduler.config.prediction_type == "v_prediction":
    mse_loss_weights = mse_loss_weights / (snr + 1)
if noise_scheduler.config.prediction_type == "v_prediction":
    # Velocity objective requires that we add one to SNR values before we divide by them.
    snr = snr + 1
mse_loss_weights = (
    torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)

loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
@@ -6,7 +6,7 @@ from accelerate import load_checkpoint_and_dispatch

from transformers import CLIPTextModelWithProjection, CLIPTokenizer

from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel
from diffusers.models.transformers.prior_transformer import PriorTransformer
from diffusers.models.prior_transformer import PriorTransformer
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler
@@ -6,7 +6,7 @@ import torch

from accelerate import load_checkpoint_and_dispatch

from diffusers import UNet2DConditionModel
from diffusers.models.transformers.prior_transformer import PriorTransformer
from diffusers.models.prior_transformer import PriorTransformer
from diffusers.models.vq_model import VQModel
@@ -4,7 +4,7 @@ import tempfile

import torch
from accelerate import load_checkpoint_and_dispatch

from diffusers.models.transformers.prior_transformer import PriorTransformer
from diffusers.models.prior_transformer import PriorTransformer
from diffusers.pipelines.shap_e import ShapERenderer
@@ -453,91 +453,3 @@ class TextualInversionLoaderMixin:

            self.enable_sequential_cpu_offload()

    # / Unsafe Code >

    def unload_textual_inversion(
        self,
        tokens: Optional[Union[str, List[str]]] = None,
    ):
        r"""
        Unload Textual Inversion embeddings from the text encoder of [`StableDiffusionPipeline`]

        Example:
        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained("runwayml/stable-diffusion-v1-5")

        # Example 1
        pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")
        pipeline.load_textual_inversion("sd-concepts-library/moeb-style")

        # Remove all token embeddings
        pipeline.unload_textual_inversion()

        # Example 2
        pipeline.load_textual_inversion("sd-concepts-library/moeb-style")
        pipeline.load_textual_inversion("sd-concepts-library/gta5-artwork")

        # Remove just one token
        pipeline.unload_textual_inversion("<moe-bius>")
        ```
        """

        tokenizer = getattr(self, "tokenizer", None)
        text_encoder = getattr(self, "text_encoder", None)

        # Get textual inversion tokens and ids
        token_ids = []
        last_special_token_id = None

        if tokens:
            if isinstance(tokens, str):
                tokens = [tokens]
            for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
                if not added_token.special:
                    if added_token.content in tokens:
                        token_ids.append(added_token_id)
                else:
                    last_special_token_id = added_token_id
            if len(token_ids) == 0:
                raise ValueError("No tokens to remove found")
        else:
            tokens = []
            for added_token_id, added_token in tokenizer.added_tokens_decoder.items():
                if not added_token.special:
                    token_ids.append(added_token_id)
                    tokens.append(added_token.content)
                else:
                    last_special_token_id = added_token_id

        # Delete from tokenizer
        for token_id, token_to_remove in zip(token_ids, tokens):
            del tokenizer._added_tokens_decoder[token_id]
            del tokenizer._added_tokens_encoder[token_to_remove]

        # Make all token ids sequential in tokenizer
        key_id = 1
        for token_id in tokenizer.added_tokens_decoder:
            if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
                token = tokenizer._added_tokens_decoder[token_id]
                tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
                del tokenizer._added_tokens_decoder[token_id]
                tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
                key_id += 1
        tokenizer._update_trie()

        # Delete from text encoder
        text_embedding_dim = text_encoder.get_input_embeddings().embedding_dim
        temp_text_embedding_weights = text_encoder.get_input_embeddings().weight
        text_embedding_weights = temp_text_embedding_weights[: last_special_token_id + 1]
        to_append = []
        for i in range(last_special_token_id + 1, temp_text_embedding_weights.shape[0]):
            if i not in token_ids:
                to_append.append(temp_text_embedding_weights[i].unsqueeze(0))
        if len(to_append) > 0:
            to_append = torch.cat(to_append, dim=0)
            text_embedding_weights = torch.cat([text_embedding_weights, to_append], dim=0)
        text_embeddings_filtered = nn.Embedding(text_embedding_weights.shape[0], text_embedding_dim)
        text_embeddings_filtered.weight.data = text_embedding_weights
        text_encoder.set_input_embeddings(text_embeddings_filtered)
@@ -16,7 +16,6 @@ import os

from collections import defaultdict
from contextlib import nullcontext
from functools import partial
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union

import safetensors
@@ -504,9 +503,8 @@ class UNet2DConditionLoadersMixin:

        weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME

        # Save the model
        save_path = Path(save_directory, weight_name).as_posix()
        save_function(state_dict, save_path)
        logger.info(f"Model weights saved in {save_path}")
        save_function(state_dict, os.path.join(save_directory, weight_name))
        logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")

    def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
        self.lora_scale = lora_scale
@@ -35,10 +35,10 @@ if is_torch_available():

    _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
    _import_structure["embeddings"] = ["ImageProjection"]
    _import_structure["modeling_utils"] = ["ModelMixin"]
    _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
    _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
    _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
    _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
    _import_structure["prior_transformer"] = ["PriorTransformer"]
    _import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
    _import_structure["transformer_2d"] = ["Transformer2DModel"]
    _import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
    _import_structure["unets.unet_1d"] = ["UNet1DModel"]
    _import_structure["unets.unet_2d"] = ["UNet2DModel"]
    _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -66,15 +66,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:

        ConsistencyDecoderVAE,
    )
    from .controlnet import ControlNetModel
    from .dual_transformer_2d import DualTransformer2DModel
    from .embeddings import ImageProjection
    from .modeling_utils import ModelMixin
    from .transformers import (
        DualTransformer2DModel,
        PriorTransformer,
        T5FilmDecoder,
        Transformer2DModel,
        TransformerTemporalModel,
    )
    from .prior_transformer import PriorTransformer
    from .t5_film_transformer import T5FilmDecoder
    from .transformer_2d import Transformer2DModel
    from .transformer_temporal import TransformerTemporalModel
    from .unets import (
        Kandinsky3UNet,
        MotionAdapter,
@@ -11,10 +11,145 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ..utils import deprecate
|
||||
from .transformers.dual_transformer_2d import DualTransformer2DModel
|
||||
from typing import Optional
|
||||
|
||||
from torch import nn
|
||||
|
||||
from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
|
||||
|
||||
|
||||
class DualTransformer2DModel(DualTransformer2DModel):
|
||||
deprecation_message = "Importing `DualTransformer2DModel` from `diffusers.models.dual_transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel`, instead."
|
||||
deprecate("DualTransformer2DModel", "0.29", deprecation_message)
|
||||
class DualTransformer2DModel(nn.Module):
|
||||
"""
|
||||
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
Pass if the input is continuous. The number of channels in the input and output.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
||||
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
||||
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
||||
`ImagePositionalEmbeddings`.
|
||||
num_vector_embeds (`int`, *optional*):
|
||||
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
||||
Includes the class for the masked latent pixel.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
||||
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
||||
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
|
||||
up to, but not more than, `num_embeds_ada_norm` steps.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
num_vector_embeds: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.transformers = nn.ModuleList(
|
||||
[
|
||||
Transformer2DModel(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
dropout=dropout,
|
||||
norm_num_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_bias=attention_bias,
|
||||
sample_size=sample_size,
|
||||
num_vector_embeds=num_vector_embeds,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
)
|
||||
for _ in range(2)
|
||||
]
|
||||
)
|
||||
|
||||
# Variables that can be set by a pipeline:
|
||||
|
||||
# The ratio of transformer1 to transformer2's output states to be combined during inference
|
||||
self.mix_ratio = 0.5
|
||||
|
||||
# The shape of `encoder_hidden_states` is expected to be
|
||||
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
|
||||
self.condition_lengths = [77, 257]
|
||||
|
||||
# Which transformer to use to encode which condition.
|
||||
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
|
||||
self.transformer_index_for_condition = [1, 0]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep=None,
|
||||
attention_mask=None,
|
||||
cross_attention_kwargs=None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
||||
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
||||
hidden_states.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Optional attention mask to be applied in Attention.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
|
||||
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
input_states = hidden_states
|
||||
|
||||
encoded_states = []
|
||||
tokens_start = 0
|
||||
# attention_mask is not used yet
|
||||
for i in range(2):
|
||||
# for each of the two transformers, pass the corresponding condition tokens
|
||||
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
|
||||
transformer_index = self.transformer_index_for_condition[i]
|
||||
encoded_state = self.transformers[transformer_index](
|
||||
input_states,
|
||||
encoder_hidden_states=condition_state,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
encoded_states.append(encoded_state - input_states)
|
||||
tokens_start += self.condition_lengths[i]
|
||||
|
||||
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
|
||||
output_states = output_states + input_states
|
||||
|
||||
if not return_dict:
|
||||
return (output_states,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output_states)
|
||||
|
||||
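The residual mixing in `forward` above is the core of this dual-transformer trick. A minimal sketch of the same computation, with the `run_transformer` callable and all shapes assumed for illustration (not taken from the diff):

```python
import torch

def mix_dual_outputs(hidden_states, encoder_hidden_states, run_transformer,
                     condition_lengths=(77, 257), transformer_index_for_condition=(1, 0),
                     mix_ratio=0.5):
    # Split the concatenated conditions, run each through its assigned transformer,
    # and keep only the residual each one adds on top of the shared input.
    residuals, tokens_start = [], 0
    for i in range(2):
        cond = encoder_hidden_states[:, tokens_start : tokens_start + condition_lengths[i]]
        out = run_transformer(transformer_index_for_condition[i], hidden_states, cond)
        residuals.append(out - hidden_states)
        tokens_start += condition_lengths[i]
    # Blend the two residuals and add the input back in.
    return hidden_states + mix_ratio * residuals[0] + (1 - mix_ratio) * residuals[1]
```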
@@ -42,7 +42,7 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
|
||||
from ..utils.hub_utils import PushToHubMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -377,11 +377,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
|
||||
|
||||
if push_to_hub:
|
||||
# Create a new empty model card and eventually tag it
|
||||
model_card = load_or_create_model_card(repo_id, token=token)
|
||||
model_card = populate_model_card(model_card)
|
||||
model_card.save(os.path.join(save_directory, "README.md"))
|
||||
|
||||
self._upload_folder(
|
||||
save_directory,
|
||||
repo_id,
|
||||
|
||||
@@ -1,12 +1,380 @@
|
||||
from ..utils import deprecate
|
||||
from .transformers.prior_transformer import PriorTransformer, PriorTransformerOutput
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||
from ..utils import BaseOutput
|
||||
from .attention import BasicTransformerBlock
|
||||
from .attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
|
||||
|
||||
class PriorTransformerOutput(PriorTransformerOutput):
|
||||
deprecation_message = "Importing `PriorTransformerOutput` from `diffusers.models.prior_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.prior_transformer import PriorTransformerOutput`, instead."
|
||||
deprecate("PriorTransformerOutput", "0.29", deprecation_message)
|
||||
@dataclass
|
||||
class PriorTransformerOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`PriorTransformer`].
|
||||
|
||||
Args:
|
||||
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
|
||||
"""
|
||||
|
||||
predicted_image_embedding: torch.FloatTensor
|
||||
|
||||
|
||||
class PriorTransformer(PriorTransformer):
|
||||
deprecation_message = "Importing `PriorTransformer` from `diffusers.models.prior_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.prior_transformer import PriorTransformer`, instead."
|
||||
deprecate("PriorTransformer", "0.29", deprecation_message)
|
||||
class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
|
||||
"""
|
||||
A Prior Transformer model.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
|
||||
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
|
||||
num_embeddings (`int`, *optional*, defaults to 77):
|
||||
The number of embeddings of the model input `hidden_states`
|
||||
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
|
||||
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
|
||||
additional_embeddings`.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
|
||||
The activation function to use to create timestep embeddings.
|
||||
norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
|
||||
passing to Transformer blocks. Set it to `None` if normalization is not needed.
|
||||
embedding_proj_norm_type (`str`, *optional*, defaults to None):
|
||||
The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
|
||||
needed.
|
||||
encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
|
||||
The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
|
||||
`encoder_hidden_states` is `None`.
|
||||
added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
|
||||
Choose from `prd` or `None`. If `prd` is chosen, it will prepend a token indicating the (quantized) dot
|
||||
product between the text embedding and image embedding as proposed in the unclip paper
|
||||
https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
|
||||
time_embed_dim (`int`, *optional*, defaults to None): The dimension of timestep embeddings.
|
||||
If None, will be set to `num_attention_heads * attention_head_dim`
|
||||
embedding_proj_dim (`int`, *optional*, default to None):
|
||||
The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
|
||||
clip_embed_dim (`int`, *optional*, default to None):
|
||||
The dimension of the output. If None, will be set to `embedding_dim`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 32,
|
||||
attention_head_dim: int = 64,
|
||||
num_layers: int = 20,
|
||||
embedding_dim: int = 768,
|
||||
num_embeddings=77,
|
||||
additional_embeddings=4,
|
||||
dropout: float = 0.0,
|
||||
time_embed_act_fn: str = "silu",
|
||||
norm_in_type: Optional[str] = None, # layer
|
||||
embedding_proj_norm_type: Optional[str] = None, # layer
|
||||
encoder_hid_proj_type: Optional[str] = "linear", # linear
|
||||
added_emb_type: Optional[str] = "prd", # prd
|
||||
time_embed_dim: Optional[int] = None,
|
||||
embedding_proj_dim: Optional[int] = None,
|
||||
clip_embed_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.additional_embeddings = additional_embeddings
|
||||
|
||||
time_embed_dim = time_embed_dim or inner_dim
|
||||
embedding_proj_dim = embedding_proj_dim or embedding_dim
|
||||
clip_embed_dim = clip_embed_dim or embedding_dim
|
||||
|
||||
self.time_proj = Timesteps(inner_dim, True, 0)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
|
||||
|
||||
self.proj_in = nn.Linear(embedding_dim, inner_dim)
|
||||
|
||||
if embedding_proj_norm_type is None:
|
||||
self.embedding_proj_norm = None
|
||||
elif embedding_proj_norm_type == "layer":
|
||||
self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
|
||||
else:
|
||||
raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
|
||||
|
||||
self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
|
||||
|
||||
if encoder_hid_proj_type is None:
|
||||
self.encoder_hidden_states_proj = None
|
||||
elif encoder_hid_proj_type == "linear":
|
||||
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
|
||||
else:
|
||||
raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
|
||||
|
||||
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
|
||||
|
||||
if added_emb_type == "prd":
|
||||
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
|
||||
elif added_emb_type is None:
|
||||
self.prd_embedding = None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
activation_fn="gelu",
|
||||
attention_bias=True,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
if norm_in_type == "layer":
|
||||
self.norm_in = nn.LayerNorm(inner_dim)
|
||||
elif norm_in_type is None:
|
||||
self.norm_in = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
|
||||
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
|
||||
self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
|
||||
|
||||
causal_attention_mask = torch.full(
|
||||
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
|
||||
)
|
||||
causal_attention_mask.triu_(1)
|
||||
causal_attention_mask = causal_attention_mask[None, ...]
|
||||
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
|
||||
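# Editor's note (illustrative, not part of the diff): with a total length of 3 the buffer above is
#   [[0., -10000., -10000.],
#    [0.,      0., -10000.],
#    [0.,      0.,      0.]]
# i.e. an additive bias that blocks attention to future tokens; `forward` adds it to the
# (optional) padded text mask before the transformer blocks run.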
|
||||
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model,
|
||||
indexed by their weight names.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
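A brief usage sketch of this API; `model` stands for any module exposing these mixin-style methods (for example a `PriorTransformer` instance) and is an assumption, not something defined in the diff:

```python
from diffusers.models.attention_processor import AttnProcessor

# One processor instance shared by every attention layer:
model.set_attn_processor(AttnProcessor())

# Or a per-layer mapping keyed by the names exposed via `model.attn_processors`:
model.set_attn_processor({name: AttnProcessor() for name in model.attn_processors.keys()})
```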
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
proj_embedding: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`PriorTransformer`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
The currently predicted image embeddings.
|
||||
timestep (`torch.LongTensor`):
|
||||
Current denoising step.
|
||||
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
Projected embedding vector the denoising process is conditioned on.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
|
||||
Hidden states of the text embeddings the denoising process is conditioned on.
|
||||
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
|
||||
Text mask for the text embeddings.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(hidden_states.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
|
||||
|
||||
timesteps_projected = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might be fp16, so we need to cast here.
|
||||
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
|
||||
time_embeddings = self.time_embedding(timesteps_projected)
|
||||
|
||||
if self.embedding_proj_norm is not None:
|
||||
proj_embedding = self.embedding_proj_norm(proj_embedding)
|
||||
|
||||
proj_embeddings = self.embedding_proj(proj_embedding)
|
||||
if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
|
||||
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
|
||||
elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
|
||||
raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
|
||||
|
||||
additional_embeds = []
|
||||
additional_embeddings_len = 0
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
additional_embeds.append(encoder_hidden_states)
|
||||
additional_embeddings_len += encoder_hidden_states.shape[1]
|
||||
|
||||
if len(proj_embeddings.shape) == 2:
|
||||
proj_embeddings = proj_embeddings[:, None, :]
|
||||
|
||||
if len(hidden_states.shape) == 2:
|
||||
hidden_states = hidden_states[:, None, :]
|
||||
|
||||
additional_embeds = additional_embeds + [
|
||||
proj_embeddings,
|
||||
time_embeddings[:, None, :],
|
||||
hidden_states,
|
||||
]
|
||||
|
||||
if self.prd_embedding is not None:
|
||||
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
|
||||
additional_embeds.append(prd_embedding)
|
||||
|
||||
hidden_states = torch.cat(
|
||||
additional_embeds,
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens
|
||||
additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
|
||||
if positional_embeddings.shape[1] < hidden_states.shape[1]:
|
||||
positional_embeddings = F.pad(
|
||||
positional_embeddings,
|
||||
(
|
||||
0,
|
||||
0,
|
||||
additional_embeddings_len,
|
||||
self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
|
||||
),
|
||||
value=0.0,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + positional_embeddings
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
|
||||
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
|
||||
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
|
||||
|
||||
if self.norm_in is not None:
|
||||
hidden_states = self.norm_in(hidden_states)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
if self.prd_embedding is not None:
|
||||
hidden_states = hidden_states[:, -1]
|
||||
else:
|
||||
hidden_states = hidden_states[:, additional_embeddings_len:]
|
||||
|
||||
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (predicted_image_embedding,)
|
||||
|
||||
return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
|
||||
|
||||
def post_process_latents(self, prior_latents):
|
||||
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
|
||||
return prior_latents
|
||||
|
||||
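For orientation, a minimal forward pass through a small prior; the sizes and variable names below are illustrative assumptions (an editor's sketch, not values used by any shipped checkpoint):

```python
import torch
from diffusers.models import PriorTransformer

prior = PriorTransformer(
    num_attention_heads=2, attention_head_dim=32, num_layers=2, embedding_dim=64, num_embeddings=77
)
noisy_image_embeds = torch.randn(1, 64)        # hidden_states: the current noisy CLIP image embedding
text_embeds = torch.randn(1, 64)               # proj_embedding: pooled CLIP text embedding
text_hidden_states = torch.randn(1, 77, 64)    # encoder_hidden_states: per-token text states
text_mask = torch.ones(1, 77, dtype=torch.bool)

out = prior(
    noisy_image_embeds,
    timestep=10,
    proj_embedding=text_embeds,
    encoder_hidden_states=text_hidden_states,
    attention_mask=text_mask,
)
denoised = prior.post_process_latents(out.predicted_image_embedding)  # back to CLIP embedding scale
```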
@@ -11,60 +11,428 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ..utils import deprecate
|
||||
from .transformers.t5_film_transformer import (
|
||||
DecoderLayer,
|
||||
NewGELUActivation,
|
||||
T5DenseGatedActDense,
|
||||
T5FilmDecoder,
|
||||
T5FiLMLayer,
|
||||
T5LayerCrossAttention,
|
||||
T5LayerFFCond,
|
||||
T5LayerNorm,
|
||||
T5LayerSelfAttentionCond,
|
||||
)
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .attention_processor import Attention
|
||||
from .embeddings import get_timestep_embedding
|
||||
from .modeling_utils import ModelMixin
|
||||
|
||||
|
||||
class T5FilmDecoder(T5FilmDecoder):
|
||||
deprecation_message = "Importing `T5FilmDecoder` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5FilmDecoder`, instead."
|
||||
deprecate("T5FilmDecoder", "0.29", deprecation_message)
|
||||
class T5FilmDecoder(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
T5 style decoder with FiLM conditioning.
|
||||
|
||||
Args:
|
||||
input_dims (`int`, *optional*, defaults to `128`):
|
||||
The number of input dimensions.
|
||||
targets_length (`int`, *optional*, defaults to `256`):
|
||||
The length of the targets.
|
||||
d_model (`int`, *optional*, defaults to `768`):
|
||||
Size of the input hidden states.
|
||||
num_layers (`int`, *optional*, defaults to `12`):
|
||||
The number of `DecoderLayer`'s to use.
|
||||
num_heads (`int`, *optional*, defaults to `12`):
|
||||
The number of attention heads to use.
|
||||
d_kv (`int`, *optional*, defaults to `64`):
|
||||
Size of the key-value projection vectors.
|
||||
d_ff (`int`, *optional*, defaults to `2048`):
|
||||
The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s.
|
||||
dropout_rate (`float`, *optional*, defaults to `0.1`):
|
||||
Dropout probability.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int = 128,
|
||||
targets_length: int = 256,
|
||||
max_decoder_noise_time: float = 2000.0,
|
||||
d_model: int = 768,
|
||||
num_layers: int = 12,
|
||||
num_heads: int = 12,
|
||||
d_kv: int = 64,
|
||||
d_ff: int = 2048,
|
||||
dropout_rate: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conditioning_emb = nn.Sequential(
|
||||
nn.Linear(d_model, d_model * 4, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(d_model * 4, d_model * 4, bias=False),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
self.position_encoding = nn.Embedding(targets_length, d_model)
|
||||
self.position_encoding.weight.requires_grad = False
|
||||
|
||||
self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)
|
||||
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
|
||||
self.decoders = nn.ModuleList()
|
||||
for lyr_num in range(num_layers):
|
||||
# FiLM conditional T5 decoder
|
||||
lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
|
||||
self.decoders.append(lyr)
|
||||
|
||||
self.decoder_norm = T5LayerNorm(d_model)
|
||||
|
||||
self.post_dropout = nn.Dropout(p=dropout_rate)
|
||||
self.spec_out = nn.Linear(d_model, input_dims, bias=False)
|
||||
|
||||
def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
|
||||
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
|
||||
return mask.unsqueeze(-3)
|
||||
|
||||
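# Editor's note (illustrative shapes): `encoder_decoder_mask` takes a query mask of shape
# (batch, query_len) and a key mask of shape (batch, key_len) and returns their outer product
# reshaped to (batch, 1, query_len, key_len), i.e. position (q, k) is nonzero only when both the
# decoder token q and the encoder token k are real (non-padding) tokens.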
def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
|
||||
batch, _, _ = decoder_input_tokens.shape
|
||||
assert decoder_noise_time.shape == (batch,)
|
||||
|
||||
# decoder_noise_time is in [0, 1), so rescale to expected timing range.
|
||||
time_steps = get_timestep_embedding(
|
||||
decoder_noise_time * self.config.max_decoder_noise_time,
|
||||
embedding_dim=self.config.d_model,
|
||||
max_period=self.config.max_decoder_noise_time,
|
||||
).to(dtype=self.dtype)
|
||||
|
||||
conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)
|
||||
|
||||
assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)
|
||||
|
||||
seq_length = decoder_input_tokens.shape[1]
|
||||
|
||||
# If we want to use relative positions for audio context, we can just offset
|
||||
# this sequence by the length of encodings_and_masks.
|
||||
decoder_positions = torch.broadcast_to(
|
||||
torch.arange(seq_length, device=decoder_input_tokens.device),
|
||||
(batch, seq_length),
|
||||
)
|
||||
|
||||
position_encodings = self.position_encoding(decoder_positions)
|
||||
|
||||
inputs = self.continuous_inputs_projection(decoder_input_tokens)
|
||||
inputs += position_encodings
|
||||
y = self.dropout(inputs)
|
||||
|
||||
# decoder: No padding present.
|
||||
decoder_mask = torch.ones(
|
||||
decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
|
||||
)
|
||||
|
||||
# Translate encoding masks to encoder-decoder masks.
|
||||
encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]
|
||||
|
||||
# cross attend style: concat encodings
|
||||
encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
|
||||
encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)
|
||||
|
||||
for lyr in self.decoders:
|
||||
y = lyr(
|
||||
y,
|
||||
conditioning_emb=conditioning_emb,
|
||||
encoder_hidden_states=encoded,
|
||||
encoder_attention_mask=encoder_decoder_mask,
|
||||
)[0]
|
||||
|
||||
y = self.decoder_norm(y)
|
||||
y = self.post_dropout(y)
|
||||
|
||||
spec_out = self.spec_out(y)
|
||||
return spec_out
|
||||
|
||||
|
||||
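A small driving example for the decoder above (an editor's sketch with assumed shapes; in practice the Spectrogram Diffusion pipeline builds `encodings_and_masks` from its note and continuous encoders):

```python
import torch
from diffusers.models import T5FilmDecoder

decoder = T5FilmDecoder(input_dims=128, targets_length=256, d_model=768, num_layers=2, num_heads=12)

tokens = torch.randn(1, 256, 128)                            # noisy spectrogram targets
encodings = [(torch.randn(1, 64, 768), torch.ones(1, 64))]   # one (encoder states, encoder mask) pair
noise_time = torch.tensor([0.5])                             # decoder_noise_time lives in [0, 1)

spec_out = decoder(
    encodings_and_masks=encodings,
    decoder_input_tokens=tokens,
    decoder_noise_time=noise_time,
)                                                            # -> (1, 256, 128)
```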
class DecoderLayer(DecoderLayer):
|
||||
deprecation_message = "Importing `DecoderLayer` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import DecoderLayer`, instead."
|
||||
deprecate("DecoderLayer", "0.29", deprecation_message)
|
||||
class DecoderLayer(nn.Module):
|
||||
r"""
|
||||
T5 decoder layer.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_kv (`int`):
|
||||
Size of the key-value projection vectors.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
d_ff (`int`):
|
||||
Size of the intermediate feed-forward layer.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
|
||||
):
|
||||
super().__init__()
|
||||
self.layer = nn.ModuleList()
|
||||
|
||||
# cond self attention: layer 0
|
||||
self.layer.append(
|
||||
T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
|
||||
)
|
||||
|
||||
# cross attention: layer 1
|
||||
self.layer.append(
|
||||
T5LayerCrossAttention(
|
||||
d_model=d_model,
|
||||
d_kv=d_kv,
|
||||
num_heads=num_heads,
|
||||
dropout_rate=dropout_rate,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
)
|
||||
)
|
||||
|
||||
# Film Cond MLP + dropout: last layer
|
||||
self.layer.append(
|
||||
T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
conditioning_emb: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_decoder_position_bias=None,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
hidden_states = self.layer[0](
|
||||
hidden_states,
|
||||
conditioning_emb=conditioning_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
|
||||
encoder_hidden_states.dtype
|
||||
)
|
||||
|
||||
hidden_states = self.layer[1](
|
||||
hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_extended_attention_mask,
|
||||
)
|
||||
|
||||
# Apply Film Conditional Feed Forward layer
|
||||
hidden_states = self.layer[-1](hidden_states, conditioning_emb)
|
||||
|
||||
return (hidden_states,)
|
||||
|
||||
|
||||
class T5LayerSelfAttentionCond(T5LayerSelfAttentionCond):
|
||||
deprecation_message = "Importing `T5LayerSelfAttentionCond` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerSelfAttentionCond`, instead."
|
||||
deprecate("T5LayerSelfAttentionCond", "0.29", deprecation_message)
|
||||
class T5LayerSelfAttentionCond(nn.Module):
|
||||
r"""
|
||||
T5 style self-attention layer with conditioning.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_kv (`int`):
|
||||
Size of the key-value projection vectors.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
|
||||
super().__init__()
|
||||
self.layer_norm = T5LayerNorm(d_model)
|
||||
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
|
||||
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
conditioning_emb: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
# pre_self_attention_layer_norm
|
||||
normed_hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
if conditioning_emb is not None:
|
||||
normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)
|
||||
|
||||
# Self-attention block
|
||||
attention_output = self.attention(normed_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + self.dropout(attention_output)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class T5LayerCrossAttention(T5LayerCrossAttention):
|
||||
deprecation_message = "Importing `T5LayerCrossAttention` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerCrossAttention`, instead."
|
||||
deprecate("T5LayerCrossAttention", "0.29", deprecation_message)
|
||||
class T5LayerCrossAttention(nn.Module):
|
||||
r"""
|
||||
T5 style cross-attention layer.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_kv (`int`):
|
||||
Size of the key-value projection vectors.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
layer_norm_epsilon (`float`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
|
||||
super().__init__()
|
||||
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
|
||||
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
key_value_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
normed_hidden_states = self.layer_norm(hidden_states)
|
||||
attention_output = self.attention(
|
||||
normed_hidden_states,
|
||||
encoder_hidden_states=key_value_states,
|
||||
attention_mask=attention_mask.squeeze(1),
|
||||
)
|
||||
layer_output = hidden_states + self.dropout(attention_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
class T5LayerFFCond(T5LayerFFCond):
|
||||
deprecation_message = "Importing `T5LayerFFCond` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerFFCond`, instead."
|
||||
deprecate("T5LayerFFCond", "0.29", deprecation_message)
|
||||
class T5LayerFFCond(nn.Module):
|
||||
r"""
|
||||
T5 style feed-forward conditional layer.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_ff (`int`):
|
||||
Size of the intermediate feed-forward layer.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
layer_norm_epsilon (`float`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
|
||||
super().__init__()
|
||||
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
|
||||
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
|
||||
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
|
||||
) -> torch.FloatTensor:
|
||||
forwarded_states = self.layer_norm(hidden_states)
|
||||
if conditioning_emb is not None:
|
||||
forwarded_states = self.film(forwarded_states, conditioning_emb)
|
||||
|
||||
forwarded_states = self.DenseReluDense(forwarded_states)
|
||||
hidden_states = hidden_states + self.dropout(forwarded_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class T5DenseGatedActDense(T5DenseGatedActDense):
|
||||
deprecation_message = "Importing `T5DenseGatedActDense` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5DenseGatedActDense`, instead."
|
||||
deprecate("T5DenseGatedActDense", "0.29", deprecation_message)
|
||||
class T5DenseGatedActDense(nn.Module):
|
||||
r"""
|
||||
T5 style feed-forward layer with gated activations and dropout.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_ff (`int`):
|
||||
Size of the intermediate feed-forward layer.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
|
||||
super().__init__()
|
||||
self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
|
||||
self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
|
||||
self.wo = nn.Linear(d_ff, d_model, bias=False)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.act = NewGELUActivation()
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_gelu = self.act(self.wi_0(hidden_states))
|
||||
hidden_linear = self.wi_1(hidden_states)
|
||||
hidden_states = hidden_gelu * hidden_linear
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
hidden_states = self.wo(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
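The gate above is a GEGLU-style feed-forward. A functional sketch (editor's illustration; the module's own `NewGELUActivation` is approximated here with PyTorch's built-in tanh GELU):

```python
import torch
import torch.nn.functional as F
from torch import nn

d_model, d_ff = 8, 32
wi_0 = nn.Linear(d_model, d_ff, bias=False)   # gate branch
wi_1 = nn.Linear(d_model, d_ff, bias=False)   # linear branch
wo = nn.Linear(d_ff, d_model, bias=False)     # output projection

x = torch.randn(2, 5, d_model)
y = wo(F.gelu(wi_0(x), approximate="tanh") * wi_1(x))   # dropout omitted for brevity
```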
class T5LayerNorm(T5LayerNorm):
|
||||
deprecation_message = "Importing `T5LayerNorm` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerNorm`, instead."
|
||||
deprecate("T5LayerNorm", "0.29", deprecation_message)
|
||||
class T5LayerNorm(nn.Module):
|
||||
r"""
|
||||
T5 style layer normalization module.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
Size of the input hidden states.
|
||||
eps (`float`, `optional`, defaults to `1e-6`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
||||
"""
|
||||
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||
# half-precision inputs is done in fp32
|
||||
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states
|
||||
|
||||
|
||||
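For reference, a functional equivalent of this norm (editor's sketch): it is an RMSNorm, scaling by the root-mean-square of the features with no mean subtraction and no bias.

```python
import torch

def t5_layer_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)   # accumulate in fp32
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(weight.dtype)
```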
class NewGELUActivation(NewGELUActivation):
|
||||
deprecation_message = "Importing `T5LayerNorm` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import NewGELUActivation`, instead."
|
||||
deprecate("NewGELUActivation", "0.29", deprecation_message)
|
||||
class NewGELUActivation(nn.Module):
|
||||
"""
|
||||
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
|
||||
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
|
||||
|
||||
|
||||
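The tanh formula above matches PyTorch's built-in approximation; a quick check (editor's sketch):

```python
import math

import torch
import torch.nn.functional as F

x = torch.randn(4)
y_tanh = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))
assert torch.allclose(y_tanh, F.gelu(x, approximate="tanh"), atol=1e-6)
```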
class T5FiLMLayer(T5FiLMLayer):
|
||||
deprecation_message = "Importing `T5FiLMLayer` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5FiLMLayer`, instead."
|
||||
deprecate("T5FiLMLayer", "0.29", deprecation_message)
|
||||
class T5FiLMLayer(nn.Module):
|
||||
"""
|
||||
T5 style FiLM Layer.
|
||||
|
||||
Args:
|
||||
in_features (`int`):
|
||||
Number of input features.
|
||||
out_features (`int`):
|
||||
Number of output features.
|
||||
"""
|
||||
|
||||
def __init__(self, in_features: int, out_features: int):
|
||||
super().__init__()
|
||||
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
|
||||
|
||||
def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
emb = self.scale_bias(conditioning_emb)
|
||||
scale, shift = torch.chunk(emb, 2, -1)
|
||||
x = x * (1 + scale) + shift
|
||||
return x
|
||||
|
||||
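A standalone sketch of the FiLM modulation performed above (editor's illustration with assumed shapes):

```python
import torch
from torch import nn

d_model = 8
scale_bias = nn.Linear(4 * d_model, 2 * d_model, bias=False)   # mirrors `self.scale_bias`

x = torch.randn(2, 16, d_model)            # hidden states
cond = torch.randn(2, 1, 4 * d_model)      # conditioning embedding from the decoder stem

scale, shift = scale_bias(cond).chunk(2, dim=-1)
y = x * (1 + scale) + shift                # broadcast over the sequence dimension
```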
@@ -11,15 +11,449 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ..utils import deprecate
|
||||
from .transformers.transformer_2d import Transformer2DModel, Transformer2DModelOutput
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..models.embeddings import ImagePositionalEmbeddings
|
||||
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
|
||||
from .attention import BasicTransformerBlock
|
||||
from .embeddings import PatchEmbed, PixArtAlphaTextProjection
|
||||
from .lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
from .modeling_utils import ModelMixin
|
||||
from .normalization import AdaLayerNormSingle
|
||||
|
||||
|
||||
class Transformer2DModelOutput(Transformer2DModelOutput):
|
||||
deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput`, instead."
|
||||
deprecate("Transformer2DModelOutput", "0.29", deprecation_message)
|
||||
@dataclass
|
||||
class Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`Transformer2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
||||
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
||||
distributions for the unnoised latent pixels.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class Transformer2DModel(Transformer2DModel):
|
||||
deprecation_message = "Importing `Transformer2DModel` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_2d import Transformer2DModel`, instead."
|
||||
deprecate("Transformer2DModel", "0.29", deprecation_message)
|
||||
class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A 2D Transformer model for image-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
num_vector_embeds (`int`, *optional*):
|
||||
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
|
||||
Includes the class for the masked latent pixel.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*):
|
||||
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
||||
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
||||
added to the hidden states.
|
||||
|
||||
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
num_vector_embeds: Optional[int] = None,
|
||||
patch_size: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
use_linear_projection: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "layer_norm",
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-5,
|
||||
attention_type: str = "default",
|
||||
caption_channels: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_linear_projection = use_linear_projection
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
|
||||
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
||||
# Define whether input is continuous or discrete depending on configuration
|
||||
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
||||
self.is_input_vectorized = num_vector_embeds is not None
|
||||
self.is_input_patches = in_channels is not None and patch_size is not None
|
||||
|
||||
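# Editor's note (illustrative): the three mutually exclusive input modes resolve as
#   in_channels set, patch_size None, num_vector_embeds None -> continuous  (latent feature maps)
#   num_vector_embeds set                                     -> vectorized  (discrete VQ latent indices)
#   in_channels set, patch_size set                           -> patches     (patchified latents, DiT/PixArt style)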
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
||||
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
|
||||
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
||||
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
||||
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
||||
)
|
||||
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
||||
norm_type = "ada_norm"
|
||||
|
||||
if self.is_input_continuous and self.is_input_vectorized:
|
||||
raise ValueError(
|
||||
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
||||
" sure that either `in_channels` or `num_vector_embeds` is None."
|
||||
)
|
||||
elif self.is_input_vectorized and self.is_input_patches:
|
||||
raise ValueError(
|
||||
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
|
||||
" sure that either `num_vector_embeds` or `num_patches` is None."
|
||||
)
|
||||
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
|
||||
raise ValueError(
|
||||
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
|
||||
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
||||
)
|
||||
|
||||
# 2. Define input layers
|
||||
if self.is_input_continuous:
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
if use_linear_projection:
|
||||
self.proj_in = linear_cls(in_channels, inner_dim)
|
||||
else:
|
||||
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
||||
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
||||
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
self.num_vector_embeds = num_vector_embeds
|
||||
self.num_latent_pixels = self.height * self.width
|
||||
|
||||
self.latent_image_embedding = ImagePositionalEmbeddings(
|
||||
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
||||
)
|
||||
elif self.is_input_patches:
|
||||
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
|
||||
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
|
||||
self.patch_size = patch_size
|
||||
interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
|
||||
interpolation_scale = max(interpolation_scale, 1)
|
||||
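# Editor's note (illustrative): sample_size 64 (512px PixArt-Alpha) gives interpolation_scale 1,
# sample_size 128 (1024px) gives 2; anything below 64 is clamped to 1 by the line above.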
self.pos_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
interpolation_scale=interpolation_scale,
|
||||
)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
attention_bias=attention_bias,
|
||||
only_cross_attention=only_cross_attention,
|
||||
double_self_attention=double_self_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
attention_type=attention_type,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Define output layers
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
if self.is_input_continuous:
|
||||
# TODO: should use out_channels for continuous projections
|
||||
if use_linear_projection:
|
||||
self.proj_out = linear_cls(inner_dim, in_channels)
|
||||
else:
|
||||
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
||||
elif self.is_input_patches and norm_type != "ada_norm_single":
|
||||
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
||||
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
||||
elif self.is_input_patches and norm_type == "ada_norm_single":
|
||||
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
||||
|
||||
# 5. PixArt-Alpha blocks.
|
||||
self.adaln_single = None
|
||||
self.use_additional_conditions = False
|
||||
if norm_type == "ada_norm_single":
|
||||
self.use_additional_conditions = self.config.sample_size == 128
|
||||
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
|
||||
# additional conditions until we find a better name
|
||||
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
|
||||
|
||||
self.caption_projection = None
|
||||
if caption_channels is not None:
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerNormZero`.
|
||||
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
attention_mask ( `torch.Tensor`, *optional*):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
||||
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
||||
|
||||
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
||||
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
||||
|
||||
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
||||
above. This bias will be added to the cross-attention scores.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# Retrieve lora scale.
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
# 1. Input
|
||||
if self.is_input_continuous:
|
||||
batch, _, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
else:
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.latent_image_embedding(hidden_states)
|
||||
elif self.is_input_patches:
|
||||
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
|
||||
if self.adaln_single is not None:
|
||||
if self.use_additional_conditions and added_cond_kwargs is None:
|
||||
raise ValueError(
|
||||
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
|
||||
)
|
||||
batch_size = hidden_states.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.caption_projection is not None:
|
||||
batch_size = hidden_states.shape[0]
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
else:
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
logits = self.out(hidden_states)
|
||||
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
||||
logits = logits.permute(0, 2, 1)
|
||||
|
||||
# log(p(x_0))
|
||||
output = F.log_softmax(logits.double(), dim=1).float()
|
||||
|
||||
if self.is_input_patches:
|
||||
if self.config.norm_type != "ada_norm_single":
|
||||
conditioning = self.transformer_blocks[0].norm1.emb(
|
||||
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
hidden_states = self.proj_out_2(hidden_states)
|
||||
elif self.config.norm_type == "ada_norm_single":
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# unpatchify
|
||||
if self.adaln_single is None:
|
||||
height = width = int(hidden_states.shape[1] ** 0.5)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
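For orientation, a minimal usage sketch of the continuous-input path implemented above follows. It is not taken from the repository's tests; the config values and tensor shapes are illustrative assumptions.

```python
# Minimal sketch: exercise Transformer2DModel's continuous-input path.
# Config values and shapes are illustrative assumptions, not canonical settings.
import torch
from diffusers.models import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=32,  # inner_dim = 8 * 32 = 256
    in_channels=64,         # must be divisible by norm_num_groups (default 32)
    num_layers=1,
)

sample = torch.randn(1, 64, 32, 32)  # (batch, channel, height, width)
with torch.no_grad():
    out = model(sample).sample       # the residual is added back, so the shape is preserved
print(out.shape)  # torch.Size([1, 64, 32, 32])
```

Without `encoder_hidden_states` (and with `cross_attention_dim=None`), the blocks fall back to self-attention, matching the docstring above.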
|
||||
|
||||
@@ -11,24 +11,369 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from ..utils import deprecate
|
||||
from .transformers.transformer_temporal import (
|
||||
TransformerSpatioTemporalModel,
|
||||
TransformerTemporalModel,
|
||||
TransformerTemporalModelOutput,
|
||||
)
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
from .modeling_utils import ModelMixin
|
||||
from .resnet import AlphaBlender
|
||||
|
||||
|
||||
class TransformerTemporalModelOutput(TransformerTemporalModelOutput):
|
||||
deprecation_message = "Importing `TransformerTemporalModelOutput` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerTemporalModelOutput`, instead."
|
||||
deprecate("TransformerTemporalModelOutput", "0.29", deprecation_message)
|
||||
@dataclass
|
||||
class TransformerTemporalModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`TransformerTemporalModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
|
||||
The hidden states output conditioned on `encoder_hidden_states` input.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class TransformerTemporalModel(TransformerTemporalModel):
|
||||
deprecation_message = "Importing `TransformerTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerTemporalModel`, instead."
|
||||
deprecate("TransformerTemporalModel", "0.29", deprecation_message)
|
||||
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Transformer model for video-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the `TransformerBlock` attention should contain a bias parameter.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
||||
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
|
||||
activation functions.
|
||||
norm_elementwise_affine (`bool`, *optional*):
|
||||
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Configure if each `TransformerBlock` should contain two self-attention layers.
|
||||
positional_embeddings: (`str`, *optional*):
|
||||
The type of positional embeddings to apply to the sequence input before use.
|
||||
num_positional_embeddings: (`int`, *optional*):
|
||||
The maximum length of the sequence over which to apply positional embeddings.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
norm_elementwise_affine: bool = True,
|
||||
double_self_attention: bool = True,
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
double_self_attention=double_self_attention,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
positional_embeddings=positional_embeddings,
|
||||
num_positional_embeddings=num_positional_embeddings,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.LongTensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
class_labels: torch.LongTensor = None,
|
||||
num_frames: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> TransformerTemporalModelOutput:
|
||||
"""
|
||||
The [`TransformerTemporalModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size * num_frames, channel, height, width)`):
Input hidden_states, with the frames flattened into the batch dimension.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerNormZero`.
|
||||
num_frames (`int`, *optional*, defaults to 1):
|
||||
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
||||
returned, otherwise a `tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
batch_frames, channel, height, width = hidden_states.shape
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = (
|
||||
hidden_states[None, None, :]
|
||||
.reshape(batch_size, height, width, num_frames, channel)
|
||||
.permute(0, 3, 4, 1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return TransformerTemporalModelOutput(sample=output)
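A minimal sketch of calling the temporal transformer above on a flattened batch of frames; the config values and shapes are illustrative assumptions.

```python
# Minimal sketch: TransformerTemporalModel mixes information across frames.
# Frames are flattened into the batch dimension: (batch * num_frames, C, H, W).
import torch
from diffusers.models import TransformerTemporalModel

batch_size, num_frames = 2, 8
model = TransformerTemporalModel(
    num_attention_heads=8,
    attention_head_dim=32,
    in_channels=64,  # divisible by norm_num_groups (default 32)
)

sample = torch.randn(batch_size * num_frames, 64, 16, 16)
with torch.no_grad():
    out = model(sample, num_frames=num_frames).sample
print(out.shape)  # torch.Size([16, 64, 16, 16])
```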
|
||||
|
||||
|
||||
class TransformerSpatioTemporalModel(TransformerSpatioTemporalModel):
|
||||
deprecation_message = "Importing `TransformerSpatioTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerSpatioTemporalModel`, instead."
|
||||
deprecate("TransformerTemporalModelOutput", "0.29", deprecation_message)
|
||||
class TransformerSpatioTemporalModel(nn.Module):
|
||||
"""
|
||||
A Transformer model for video-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
out_channels (`int`, *optional*):
|
||||
The number of channels in the output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: int = 320,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.inner_dim = inner_dim
|
||||
|
||||
# 2. Define input layers
|
||||
self.in_channels = in_channels
|
||||
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
time_mix_inner_dim = inner_dim
|
||||
self.temporal_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
TemporalBasicTransformerBlock(
|
||||
inner_dim,
|
||||
time_mix_inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
time_embed_dim = in_channels * 4
|
||||
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
|
||||
self.time_proj = Timesteps(in_channels, True, 0)
|
||||
self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
|
||||
|
||||
# 4. Define output layers
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
# TODO: should use out_channels for continuous projections
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
image_only_indicator: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size * num_frames, channel, height, width)`):
Input hidden_states. The number of frames is inferred from `image_only_indicator`.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
|
||||
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
|
||||
images, 0 indicates that the input contains video frames.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
||||
returned, otherwise a `tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
batch_frames, _, height, width = hidden_states.shape
|
||||
num_frames = image_only_indicator.shape[-1]
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
time_context = encoder_hidden_states
|
||||
time_context_first_timestep = time_context[None, :].reshape(
|
||||
batch_size, num_frames, -1, time_context.shape[-1]
|
||||
)[:, 0]
|
||||
time_context = time_context_first_timestep[None, :].broadcast_to(
|
||||
height * width, batch_size, 1, time_context.shape[-1]
|
||||
)
|
||||
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
|
||||
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
|
||||
num_frames_emb = num_frames_emb.reshape(-1)
|
||||
t_emb = self.time_proj(num_frames_emb)
|
||||
|
||||
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
||||
|
||||
emb = self.time_pos_embed(t_emb)
|
||||
emb = emb[:, None, :]
|
||||
|
||||
# 2. Blocks
|
||||
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
block,
|
||||
hidden_states,
|
||||
None,
|
||||
encoder_hidden_states,
|
||||
None,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
hidden_states_mix = hidden_states
|
||||
hidden_states_mix = hidden_states_mix + emb
|
||||
|
||||
hidden_states_mix = temporal_block(
|
||||
hidden_states_mix,
|
||||
num_frames=num_frames,
|
||||
encoder_hidden_states=time_context,
|
||||
)
|
||||
hidden_states = self.time_mixer(
|
||||
x_spatial=hidden_states,
|
||||
x_temporal=hidden_states_mix,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return TransformerTemporalModelOutput(sample=output)
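A minimal sketch of a `TransformerSpatioTemporalModel` call with illustrative shapes; the import path assumes the `diffusers.models.transformer_temporal` module shown in this diff. Note that `num_attention_heads * attention_head_dim` is chosen to equal `in_channels`, since the frame-position embedding (`time_pos_embed`, `out_dim=in_channels`) is added directly to the projected hidden states.

```python
# Minimal sketch; all values are illustrative assumptions.
import torch
from diffusers.models.transformer_temporal import TransformerSpatioTemporalModel

batch_size, num_frames = 1, 4
model = TransformerSpatioTemporalModel(
    num_attention_heads=8,
    attention_head_dim=40,    # 8 * 40 = 320 = in_channels
    in_channels=320,
    cross_attention_dim=1024,
)

sample = torch.randn(batch_size * num_frames, 320, 8, 8)
encoder_hidden_states = torch.randn(batch_size * num_frames, 1, 1024)
image_only_indicator = torch.zeros(batch_size, num_frames)  # 0 = video frame, 1 = image

with torch.no_grad():
    out = model(
        sample,
        encoder_hidden_states=encoder_hidden_states,
        image_only_indicator=image_only_indicator,
    ).sample
print(out.shape)  # torch.Size([4, 320, 8, 8])
```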
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
from ...utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .dual_transformer_2d import DualTransformer2DModel
|
||||
from .prior_transformer import PriorTransformer
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
@@ -1,155 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
from torch import nn
|
||||
|
||||
from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
|
||||
|
||||
|
||||
class DualTransformer2DModel(nn.Module):
|
||||
"""
|
||||
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
Pass if the input is continuous. The number of channels in the input and output.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
||||
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
||||
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
||||
`ImagePositionalEmbeddings`.
|
||||
num_vector_embeds (`int`, *optional*):
|
||||
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
||||
Includes the class for the masked latent pixel.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
||||
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
|
||||
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more than `num_embeds_ada_norm` steps.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
num_vector_embeds: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.transformers = nn.ModuleList(
|
||||
[
|
||||
Transformer2DModel(
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
dropout=dropout,
|
||||
norm_num_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attention_bias=attention_bias,
|
||||
sample_size=sample_size,
|
||||
num_vector_embeds=num_vector_embeds,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
)
|
||||
for _ in range(2)
|
||||
]
|
||||
)
|
||||
|
||||
# Variables that can be set by a pipeline:
|
||||
|
||||
# The ratio of transformer1 to transformer2's output states to be combined during inference
|
||||
self.mix_ratio = 0.5
|
||||
|
||||
# The shape of `encoder_hidden_states` is expected to be
|
||||
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
|
||||
self.condition_lengths = [77, 257]
|
||||
|
||||
# Which transformer to use to encode which condition.
|
||||
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
|
||||
self.transformer_index_for_condition = [1, 0]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep=None,
|
||||
attention_mask=None,
|
||||
cross_attention_kwargs=None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input hidden_states.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Optional attention mask to be applied in Attention.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
|
||||
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
input_states = hidden_states
|
||||
|
||||
encoded_states = []
|
||||
tokens_start = 0
|
||||
# attention_mask is not used yet
|
||||
for i in range(2):
|
||||
# for each of the two transformers, pass the corresponding condition tokens
|
||||
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
|
||||
transformer_index = self.transformer_index_for_condition[i]
|
||||
encoded_state = self.transformers[transformer_index](
|
||||
input_states,
|
||||
encoder_hidden_states=condition_state,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
encoded_states.append(encoded_state - input_states)
|
||||
tokens_start += self.condition_lengths[i]
|
||||
|
||||
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
|
||||
output_states = output_states + input_states
|
||||
|
||||
if not return_dict:
|
||||
return (output_states,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output_states)
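A minimal sketch of how the dual transformer splits a concatenated condition sequence (77 + 257 tokens by default) between its two inner transformers; the module path and the values below are illustrative assumptions based on the layout shown in this diff.

```python
# Minimal sketch; config values and shapes are illustrative assumptions.
import torch
from diffusers.models.dual_transformer_2d import DualTransformer2DModel

model = DualTransformer2DModel(
    num_attention_heads=8,
    attention_head_dim=32,
    in_channels=64,
    cross_attention_dim=768,
)

sample = torch.randn(1, 64, 16, 16)
# Both condition streams concatenated along the sequence dimension: 77 + 257 tokens.
encoder_hidden_states = torch.randn(1, 77 + 257, 768)

with torch.no_grad():
    out = model(sample, encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 64, 16, 16])
```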
|
||||
@@ -1,380 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||
from ...utils import BaseOutput
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class PriorTransformerOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`PriorTransformer`].
|
||||
|
||||
Args:
|
||||
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
|
||||
"""
|
||||
|
||||
predicted_image_embedding: torch.FloatTensor
|
||||
|
||||
|
||||
class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
|
||||
"""
|
||||
A Prior Transformer model.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
|
||||
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
|
||||
num_embeddings (`int`, *optional*, defaults to 77):
|
||||
The number of embeddings of the model input `hidden_states`
|
||||
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
|
||||
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
|
||||
additional_embeddings`.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
|
||||
The activation function to use to create timestep embeddings.
|
||||
norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
|
||||
passing to Transformer blocks. Set it to `None` if normalization is not needed.
|
||||
embedding_proj_norm_type (`str`, *optional*, defaults to None):
|
||||
The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
|
||||
needed.
|
||||
encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
|
||||
The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
|
||||
`encoder_hidden_states` is `None`.
|
||||
added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
|
||||
Choose from `prd` or `None`. If `prd`, a token indicating the (quantized) dot product between the text
embedding and image embedding, as proposed in the unCLIP paper (https://arxiv.org/abs/2204.06125), is
prepended. If `None`, no additional embeddings are prepended.
|
||||
time_embed_dim (`int`, *optional*, defaults to None): The dimension of timestep embeddings.
|
||||
If None, will be set to `num_attention_heads * attention_head_dim`
|
||||
embedding_proj_dim (`int`, *optional*, defaults to None):
|
||||
The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
|
||||
clip_embed_dim (`int`, *optional*, defaults to None):
|
||||
The dimension of the output. If None, will be set to `embedding_dim`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 32,
|
||||
attention_head_dim: int = 64,
|
||||
num_layers: int = 20,
|
||||
embedding_dim: int = 768,
|
||||
num_embeddings=77,
|
||||
additional_embeddings=4,
|
||||
dropout: float = 0.0,
|
||||
time_embed_act_fn: str = "silu",
|
||||
norm_in_type: Optional[str] = None, # layer
|
||||
embedding_proj_norm_type: Optional[str] = None, # layer
|
||||
encoder_hid_proj_type: Optional[str] = "linear", # linear
|
||||
added_emb_type: Optional[str] = "prd", # prd
|
||||
time_embed_dim: Optional[int] = None,
|
||||
embedding_proj_dim: Optional[int] = None,
|
||||
clip_embed_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.additional_embeddings = additional_embeddings
|
||||
|
||||
time_embed_dim = time_embed_dim or inner_dim
|
||||
embedding_proj_dim = embedding_proj_dim or embedding_dim
|
||||
clip_embed_dim = clip_embed_dim or embedding_dim
|
||||
|
||||
self.time_proj = Timesteps(inner_dim, True, 0)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
|
||||
|
||||
self.proj_in = nn.Linear(embedding_dim, inner_dim)
|
||||
|
||||
if embedding_proj_norm_type is None:
|
||||
self.embedding_proj_norm = None
|
||||
elif embedding_proj_norm_type == "layer":
|
||||
self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
|
||||
else:
|
||||
raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
|
||||
|
||||
self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
|
||||
|
||||
if encoder_hid_proj_type is None:
|
||||
self.encoder_hidden_states_proj = None
|
||||
elif encoder_hid_proj_type == "linear":
|
||||
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
|
||||
else:
|
||||
raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
|
||||
|
||||
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
|
||||
|
||||
if added_emb_type == "prd":
|
||||
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
|
||||
elif added_emb_type is None:
|
||||
self.prd_embedding = None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
activation_fn="gelu",
|
||||
attention_bias=True,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
if norm_in_type == "layer":
|
||||
self.norm_in = nn.LayerNorm(inner_dim)
|
||||
elif norm_in_type is None:
|
||||
self.norm_in = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
|
||||
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
|
||||
self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
|
||||
|
||||
causal_attention_mask = torch.full(
|
||||
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
|
||||
)
|
||||
causal_attention_mask.triu_(1)
|
||||
causal_attention_mask = causal_attention_mask[None, ...]
|
||||
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
|
||||
|
||||
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
proj_embedding: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`PriorTransformer`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
The currently predicted image embeddings.
|
||||
timestep (`torch.LongTensor`):
|
||||
Current denoising step.
|
||||
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
Projected embedding vector the denoising process is conditioned on.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
|
||||
Hidden states of the text embeddings the denoising process is conditioned on.
|
||||
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
|
||||
Text mask for the text embeddings.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
|
||||
tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(hidden_states.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
|
||||
|
||||
timesteps_projected = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might be fp16, so we need to cast here.
|
||||
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
|
||||
time_embeddings = self.time_embedding(timesteps_projected)
|
||||
|
||||
if self.embedding_proj_norm is not None:
|
||||
proj_embedding = self.embedding_proj_norm(proj_embedding)
|
||||
|
||||
proj_embeddings = self.embedding_proj(proj_embedding)
|
||||
if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
|
||||
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
|
||||
elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
|
||||
raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
|
||||
|
||||
additional_embeds = []
|
||||
additional_embeddings_len = 0
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
additional_embeds.append(encoder_hidden_states)
|
||||
additional_embeddings_len += encoder_hidden_states.shape[1]
|
||||
|
||||
if len(proj_embeddings.shape) == 2:
|
||||
proj_embeddings = proj_embeddings[:, None, :]
|
||||
|
||||
if len(hidden_states.shape) == 2:
|
||||
hidden_states = hidden_states[:, None, :]
|
||||
|
||||
additional_embeds = additional_embeds + [
|
||||
proj_embeddings,
|
||||
time_embeddings[:, None, :],
|
||||
hidden_states,
|
||||
]
|
||||
|
||||
if self.prd_embedding is not None:
|
||||
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
|
||||
additional_embeds.append(prd_embedding)
|
||||
|
||||
hidden_states = torch.cat(
|
||||
additional_embeds,
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Allow positional_embedding to not include the `additional_embeddings` and instead pad it with zeros for these additional tokens
|
||||
additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
|
||||
if positional_embeddings.shape[1] < hidden_states.shape[1]:
|
||||
positional_embeddings = F.pad(
|
||||
positional_embeddings,
|
||||
(
|
||||
0,
|
||||
0,
|
||||
additional_embeddings_len,
|
||||
self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
|
||||
),
|
||||
value=0.0,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + positional_embeddings
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
|
||||
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
|
||||
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
|
||||
|
||||
if self.norm_in is not None:
|
||||
hidden_states = self.norm_in(hidden_states)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
if self.prd_embedding is not None:
|
||||
hidden_states = hidden_states[:, -1]
|
||||
else:
|
||||
hidden_states = hidden_states[:, additional_embeddings_len:]
|
||||
|
||||
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (predicted_image_embedding,)
|
||||
|
||||
return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
|
||||
|
||||
def post_process_latents(self, prior_latents):
|
||||
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
|
||||
return prior_latents
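A minimal sketch of a single `PriorTransformer` denoising call with a deliberately small configuration; shapes and values are illustrative assumptions.

```python
# Minimal sketch: predict a denoised CLIP image embedding from a noisy one,
# conditioned on text embeddings. Values are illustrative assumptions.
import torch
from diffusers import PriorTransformer

model = PriorTransformer(
    num_attention_heads=2,
    attention_head_dim=16,  # inner_dim = 32
    num_layers=2,
    embedding_dim=768,
    num_embeddings=77,
)

batch = 2
noisy_image_embeds = torch.randn(batch, 768)             # current prediction being denoised
pooled_text_embeds = torch.randn(batch, 768)             # conditioning for `proj_embedding`
text_encoder_hidden_states = torch.randn(batch, 77, 768)

with torch.no_grad():
    out = model(
        noisy_image_embeds,
        timestep=10,
        proj_embedding=pooled_text_embeds,
        encoder_hidden_states=text_encoder_hidden_states,
    ).predicted_image_embedding
print(out.shape)  # torch.Size([2, 768])
```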
|
||||
@@ -1,438 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import get_timestep_embedding
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
class T5FilmDecoder(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
T5 style decoder with FiLM conditioning.
|
||||
|
||||
Args:
|
||||
input_dims (`int`, *optional*, defaults to `128`):
|
||||
The number of input dimensions.
|
||||
targets_length (`int`, *optional*, defaults to `256`):
|
||||
The length of the targets.
|
||||
d_model (`int`, *optional*, defaults to `768`):
|
||||
Size of the input hidden states.
|
||||
num_layers (`int`, *optional*, defaults to `12`):
|
||||
The number of `DecoderLayer`'s to use.
|
||||
num_heads (`int`, *optional*, defaults to `12`):
|
||||
The number of attention heads to use.
|
||||
d_kv (`int`, *optional*, defaults to `64`):
|
||||
Size of the key-value projection vectors.
|
||||
d_ff (`int`, *optional*, defaults to `2048`):
|
||||
The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s.
|
||||
dropout_rate (`float`, *optional*, defaults to `0.1`):
|
||||
Dropout probability.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
input_dims: int = 128,
|
||||
targets_length: int = 256,
|
||||
max_decoder_noise_time: float = 2000.0,
|
||||
d_model: int = 768,
|
||||
num_layers: int = 12,
|
||||
num_heads: int = 12,
|
||||
d_kv: int = 64,
|
||||
d_ff: int = 2048,
|
||||
dropout_rate: float = 0.1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conditioning_emb = nn.Sequential(
|
||||
nn.Linear(d_model, d_model * 4, bias=False),
|
||||
nn.SiLU(),
|
||||
nn.Linear(d_model * 4, d_model * 4, bias=False),
|
||||
nn.SiLU(),
|
||||
)
|
||||
|
||||
self.position_encoding = nn.Embedding(targets_length, d_model)
|
||||
self.position_encoding.weight.requires_grad = False
|
||||
|
||||
self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)
|
||||
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
|
||||
self.decoders = nn.ModuleList()
|
||||
for lyr_num in range(num_layers):
|
||||
# FiLM conditional T5 decoder
|
||||
lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
|
||||
self.decoders.append(lyr)
|
||||
|
||||
self.decoder_norm = T5LayerNorm(d_model)
|
||||
|
||||
self.post_dropout = nn.Dropout(p=dropout_rate)
|
||||
self.spec_out = nn.Linear(d_model, input_dims, bias=False)
|
||||
|
||||
def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
|
||||
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
|
||||
return mask.unsqueeze(-3)
|
||||
|
||||
def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
|
||||
batch, _, _ = decoder_input_tokens.shape
|
||||
assert decoder_noise_time.shape == (batch,)
|
||||
|
||||
# decoder_noise_time is in [0, 1), so rescale to expected timing range.
|
||||
time_steps = get_timestep_embedding(
|
||||
decoder_noise_time * self.config.max_decoder_noise_time,
|
||||
embedding_dim=self.config.d_model,
|
||||
max_period=self.config.max_decoder_noise_time,
|
||||
).to(dtype=self.dtype)
|
||||
|
||||
conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)
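# FiLM conditioning vector of shape (batch, 1, 4 * d_model), checked by the assert below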
|
||||
|
||||
assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)
|
||||
|
||||
seq_length = decoder_input_tokens.shape[1]
|
||||
|
||||
# If we want to use relative positions for audio context, we can just offset
|
||||
# this sequence by the length of encodings_and_masks.
|
||||
decoder_positions = torch.broadcast_to(
|
||||
torch.arange(seq_length, device=decoder_input_tokens.device),
|
||||
(batch, seq_length),
|
||||
)
|
||||
|
||||
position_encodings = self.position_encoding(decoder_positions)
|
||||
|
||||
inputs = self.continuous_inputs_projection(decoder_input_tokens)
|
||||
inputs += position_encodings
|
||||
y = self.dropout(inputs)
|
||||
|
||||
# decoder: No padding present.
|
||||
decoder_mask = torch.ones(
|
||||
decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
|
||||
)
|
||||
|
||||
# Translate encoding masks to encoder-decoder masks.
|
||||
encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]
|
||||
|
||||
# cross attend style: concat encodings
|
||||
encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
|
||||
encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)
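# the decoder cross-attends to all encoder streams jointly: encodings are concatenated along the sequence axis, their masks along the key axis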
|
||||
|
||||
for lyr in self.decoders:
|
||||
y = lyr(
|
||||
y,
|
||||
conditioning_emb=conditioning_emb,
|
||||
encoder_hidden_states=encoded,
|
||||
encoder_attention_mask=encoder_decoder_mask,
|
||||
)[0]
|
||||
|
||||
y = self.decoder_norm(y)
|
||||
y = self.post_dropout(y)
|
||||
|
||||
spec_out = self.spec_out(y)
|
||||
return spec_out
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
r"""
|
||||
T5 decoder layer.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_kv (`int`):
|
||||
Size of the key-value projection vectors.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
d_ff (`int`):
|
||||
Size of the intermediate feed-forward layer.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
|
||||
):
|
||||
super().__init__()
|
||||
self.layer = nn.ModuleList()
|
||||
|
||||
# cond self attention: layer 0
|
||||
self.layer.append(
|
||||
T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
|
||||
)
|
||||
|
||||
# cross attention: layer 1
|
||||
self.layer.append(
|
||||
T5LayerCrossAttention(
|
||||
d_model=d_model,
|
||||
d_kv=d_kv,
|
||||
num_heads=num_heads,
|
||||
dropout_rate=dropout_rate,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
)
|
||||
)
|
||||
|
||||
# Film Cond MLP + dropout: last layer
|
||||
self.layer.append(
|
||||
T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
conditioning_emb: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_decoder_position_bias=None,
|
||||
) -> Tuple[torch.FloatTensor]:
|
||||
hidden_states = self.layer[0](
|
||||
hidden_states,
|
||||
conditioning_emb=conditioning_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
|
||||
encoder_hidden_states.dtype
|
||||
)
|
||||
|
||||
hidden_states = self.layer[1](
|
||||
hidden_states,
|
||||
key_value_states=encoder_hidden_states,
|
||||
attention_mask=encoder_extended_attention_mask,
|
||||
)
|
||||
|
||||
# Apply Film Conditional Feed Forward layer
|
||||
hidden_states = self.layer[-1](hidden_states, conditioning_emb)
|
||||
|
||||
return (hidden_states,)
|
||||
|
||||
|
||||
class T5LayerSelfAttentionCond(nn.Module):
|
||||
r"""
|
||||
T5 style self-attention layer with conditioning.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_kv (`int`):
|
||||
Size of the key-value projection vectors.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
|
||||
super().__init__()
|
||||
self.layer_norm = T5LayerNorm(d_model)
|
||||
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
|
||||
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
conditioning_emb: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
# pre_self_attention_layer_norm
|
||||
normed_hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
if conditioning_emb is not None:
|
||||
normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)
|
||||
|
||||
# Self-attention block
|
||||
attention_output = self.attention(normed_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + self.dropout(attention_output)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class T5LayerCrossAttention(nn.Module):
|
||||
r"""
|
||||
T5 style cross-attention layer.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_kv (`int`):
|
||||
Size of the key-value projection vectors.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
layer_norm_epsilon (`float`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
|
||||
super().__init__()
|
||||
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
|
||||
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
key_value_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
normed_hidden_states = self.layer_norm(hidden_states)
|
||||
attention_output = self.attention(
|
||||
normed_hidden_states,
|
||||
encoder_hidden_states=key_value_states,
|
||||
attention_mask=attention_mask.squeeze(1),
|
||||
)
|
||||
layer_output = hidden_states + self.dropout(attention_output)
|
||||
return layer_output
|
||||
|
||||
|
||||
class T5LayerFFCond(nn.Module):
|
||||
r"""
|
||||
T5 style feed-forward conditional layer.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_ff (`int`):
|
||||
Size of the intermediate feed-forward layer.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
layer_norm_epsilon (`float`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
|
||||
super().__init__()
|
||||
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
|
||||
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
|
||||
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
|
||||
) -> torch.FloatTensor:
|
||||
forwarded_states = self.layer_norm(hidden_states)
|
||||
if conditioning_emb is not None:
|
||||
forwarded_states = self.film(forwarded_states, conditioning_emb)
|
||||
|
||||
forwarded_states = self.DenseReluDense(forwarded_states)
|
||||
hidden_states = hidden_states + self.dropout(forwarded_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class T5DenseGatedActDense(nn.Module):
|
||||
r"""
|
||||
T5 style feed-forward layer with gated activations and dropout.
|
||||
|
||||
Args:
|
||||
d_model (`int`):
|
||||
Size of the input hidden states.
|
||||
d_ff (`int`):
|
||||
Size of the intermediate feed-forward layer.
|
||||
dropout_rate (`float`):
|
||||
Dropout probability.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
|
||||
super().__init__()
|
||||
self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
|
||||
self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
|
||||
self.wo = nn.Linear(d_ff, d_model, bias=False)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.act = NewGELUActivation()
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
hidden_gelu = self.act(self.wi_0(hidden_states))
|
||||
hidden_linear = self.wi_1(hidden_states)
|
||||
hidden_states = hidden_gelu * hidden_linear
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
hidden_states = self.wo(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class T5LayerNorm(nn.Module):
|
||||
r"""
|
||||
T5 style layer normalization module.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
Size of the input hidden states.
|
||||
eps (`float`, `optional`, defaults to `1e-6`):
|
||||
A small value used for numerical stability to avoid dividing by zero.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
||||
"""
|
||||
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
||||
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
|
||||
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
||||
# half-precision inputs is done in fp32
|
||||
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states
|
||||
|
||||
|
||||
class NewGELUActivation(nn.Module):
|
||||
"""
|
||||
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
|
||||
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
|
||||
|
||||
|
||||
class T5FiLMLayer(nn.Module):
|
||||
"""
|
||||
T5 style FiLM Layer.
|
||||
|
||||
Args:
|
||||
in_features (`int`):
|
||||
Number of input features.
|
||||
out_features (`int`):
|
||||
Number of output features.
|
||||
"""
|
||||
|
||||
def __init__(self, in_features: int, out_features: int):
|
||||
super().__init__()
|
||||
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
|
||||
|
||||
def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
|
||||
emb = self.scale_bias(conditioning_emb)
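# FiLM: predict a per-feature scale and shift from the conditioning embedding, then apply x -> x * (1 + scale) + shift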
|
||||
scale, shift = torch.chunk(emb, 2, -1)
|
||||
x = x * (1 + scale) + shift
|
||||
return x
|
||||
@@ -1,458 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
|
||||
from ..attention import BasicTransformerBlock
|
||||
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
|
||||
from ..lora import LoRACompatibleConv, LoRACompatibleLinear
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormSingle
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`Transformer2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
||||
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
||||
distributions for the unnoised latent pixels.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A 2D Transformer model for image-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
num_vector_embeds (`int`, *optional*):
|
||||
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
|
||||
Includes the class for the masked latent pixel.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*):
|
||||
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
|
||||
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
|
||||
added to the hidden states.
|
||||
|
||||
During inference, you can denoise for up to but not more than `num_embeds_ada_norm` steps.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the `TransformerBlocks` attention should contain a bias parameter.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
num_vector_embeds: Optional[int] = None,
|
||||
patch_size: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
use_linear_projection: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
double_self_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "layer_norm",
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_eps: float = 1e-5,
|
||||
attention_type: str = "default",
|
||||
caption_channels: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_linear_projection = use_linear_projection
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
|
||||
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
|
||||
|
||||
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
||||
# Define whether input is continuous or discrete depending on configuration
|
||||
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
||||
self.is_input_vectorized = num_vector_embeds is not None
|
||||
self.is_input_patches = in_channels is not None and patch_size is not None
|
||||
|
||||
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
||||
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
|
||||
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
||||
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
||||
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
||||
)
|
||||
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
||||
norm_type = "ada_norm"
|
||||
|
||||
if self.is_input_continuous and self.is_input_vectorized:
|
||||
raise ValueError(
|
||||
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
||||
" sure that either `in_channels` or `num_vector_embeds` is None."
|
||||
)
|
||||
elif self.is_input_vectorized and self.is_input_patches:
|
||||
raise ValueError(
|
||||
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
|
||||
" sure that either `num_vector_embeds` or `num_patches` is None."
|
||||
)
|
||||
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
|
||||
raise ValueError(
|
||||
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
|
||||
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
||||
)
|
||||
|
||||
# 2. Define input layers
|
||||
if self.is_input_continuous:
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
if use_linear_projection:
|
||||
self.proj_in = linear_cls(in_channels, inner_dim)
|
||||
else:
|
||||
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
||||
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
|
||||
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
self.num_vector_embeds = num_vector_embeds
|
||||
self.num_latent_pixels = self.height * self.width
|
||||
|
||||
self.latent_image_embedding = ImagePositionalEmbeddings(
|
||||
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
||||
)
|
||||
elif self.is_input_patches:
|
||||
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
|
||||
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
|
||||
self.patch_size = patch_size
|
||||
interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
|
||||
interpolation_scale = max(interpolation_scale, 1)
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
interpolation_scale=interpolation_scale,
|
||||
)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
attention_bias=attention_bias,
|
||||
only_cross_attention=only_cross_attention,
|
||||
double_self_attention=double_self_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
norm_eps=norm_eps,
|
||||
attention_type=attention_type,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Define output layers
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
if self.is_input_continuous:
|
||||
# TODO: should use out_channels for continuous projections
|
||||
if use_linear_projection:
|
||||
self.proj_out = linear_cls(inner_dim, in_channels)
|
||||
else:
|
||||
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
||||
elif self.is_input_patches and norm_type != "ada_norm_single":
|
||||
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
||||
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
||||
elif self.is_input_patches and norm_type == "ada_norm_single":
|
||||
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
||||
|
||||
# 5. PixArt-Alpha blocks.
|
||||
self.adaln_single = None
|
||||
self.use_additional_conditions = False
|
||||
if norm_type == "ada_norm_single":
|
||||
self.use_additional_conditions = self.config.sample_size == 128
|
||||
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
|
||||
# additional conditions until we find better name
|
||||
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
|
||||
|
||||
self.caption_projection = None
|
||||
if caption_channels is not None:
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
cross_attention_kwargs: Dict[str, Any] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
The [`Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerZeroNorm`.
|
||||
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
attention_mask ( `torch.Tensor`, *optional*):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
encoder_attention_mask ( `torch.Tensor`, *optional*):
|
||||
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
|
||||
|
||||
* Mask `(batch, sequence_length)` True = keep, False = discard.
|
||||
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
|
||||
|
||||
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
|
||||
above. This bias will be added to the cross-attention scores.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
|
||||
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
|
||||
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# Retrieve lora scale.
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
# 1. Input
|
||||
if self.is_input_continuous:
|
||||
batch, _, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
else:
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
hidden_states = (
|
||||
self.proj_in(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_in(hidden_states)
|
||||
)
|
||||
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.latent_image_embedding(hidden_states)
|
||||
elif self.is_input_patches:
|
||||
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
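# number of patches along each spatial dimension; used later to unpatchify the output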
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
|
||||
if self.adaln_single is not None:
|
||||
if self.use_additional_conditions and added_cond_kwargs is None:
|
||||
raise ValueError(
|
||||
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
|
||||
)
|
||||
batch_size = hidden_states.shape[0]
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
# 2. Blocks
|
||||
if self.caption_projection is not None:
|
||||
batch_size = hidden_states.shape[0]
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
timestep,
|
||||
cross_attention_kwargs,
|
||||
class_labels,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
else:
|
||||
hidden_states = (
|
||||
self.proj_out(hidden_states, scale=lora_scale)
|
||||
if not USE_PEFT_BACKEND
|
||||
else self.proj_out(hidden_states)
|
||||
)
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
logits = self.out(hidden_states)
|
||||
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
||||
logits = logits.permute(0, 2, 1)
|
||||
|
||||
# log(p(x_0))
|
||||
output = F.log_softmax(logits.double(), dim=1).float()
|
||||
|
||||
if self.is_input_patches:
|
||||
if self.config.norm_type != "ada_norm_single":
|
||||
conditioning = self.transformer_blocks[0].norm1.emb(
|
||||
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
hidden_states = self.proj_out_2(hidden_states)
|
||||
elif self.config.norm_type == "ada_norm_single":
|
||||
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# unpatchify
|
||||
if self.adaln_single is None:
|
||||
height = width = int(hidden_states.shape[1] ** 0.5)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
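# interleave the per-patch pixels back into the full spatial grid: (N, H, W, p, p, C) -> (N, C, H, p, W, p)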
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -1,379 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput
|
||||
from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..resnet import AlphaBlender
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerTemporalModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`TransformerTemporalModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
|
||||
The hidden states output conditioned on `encoder_hidden_states` input.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A Transformer model for video-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the `TransformerBlock` attention should contain a bias parameter.
|
||||
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
||||
This is fixed during training since it is used to learn a number of position embeddings.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
||||
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
|
||||
activation functions.
|
||||
norm_elementwise_affine (`bool`, *optional*):
|
||||
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Configure if each `TransformerBlock` should contain two self-attention layers.
|
||||
positional_embeddings: (`str`, *optional*):
|
||||
The type of positional embeddings to apply to the sequence input before passing it to the attention blocks.
|
||||
num_positional_embeddings: (`int`, *optional*):
|
||||
The maximum length of the sequence over which to apply positional embeddings.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
norm_elementwise_affine: bool = True,
|
||||
double_self_attention: bool = True,
|
||||
positional_embeddings: Optional[str] = None,
|
||||
num_positional_embeddings: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
double_self_attention=double_self_attention,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
positional_embeddings=positional_embeddings,
|
||||
num_positional_embeddings=num_positional_embeddings,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.LongTensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
class_labels: torch.LongTensor = None,
|
||||
num_frames: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> TransformerTemporalModelOutput:
|
||||
"""
|
||||
The [`TransformerTemporalModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
||||
Input hidden_states.
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.LongTensor`, *optional*):
|
||||
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
||||
`AdaLayerZeroNorm`.
|
||||
num_frames (`int`, *optional*, defaults to 1):
|
||||
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
||||
returned, otherwise a `tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
batch_frames, channel, height, width = hidden_states.shape
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
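# fold the spatial positions into the batch dimension so the transformer attends over the temporal (frame) axis only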
|
||||
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = (
|
||||
hidden_states[None, None, :]
|
||||
.reshape(batch_size, height, width, num_frames, channel)
|
||||
.permute(0, 3, 4, 1, 2)
|
||||
.contiguous()
|
||||
)
|
||||
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return TransformerTemporalModelOutput(sample=output)
|
||||
|
||||
|
||||
class TransformerSpatioTemporalModel(nn.Module):
|
||||
"""
|
||||
A Transformer model for video-like data.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
The number of channels in the input and output (specify if the input is **continuous**).
|
||||
out_channels (`int`, *optional*):
|
||||
The number of channels in the output (specify if the input is **continuous**).
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: int = 320,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.inner_dim = inner_dim
|
||||
|
||||
# 2. Define input layers
|
||||
self.in_channels = in_channels
|
||||
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
time_mix_inner_dim = inner_dim
|
||||
self.temporal_transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
TemporalBasicTransformerBlock(
|
||||
inner_dim,
|
||||
time_mix_inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
time_embed_dim = in_channels * 4
|
||||
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
|
||||
self.time_proj = Timesteps(in_channels, True, 0)
|
||||
self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
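# learned blending factor between the spatial and temporal branches; image-only samples are flagged via image_only_indicator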
|
||||
|
||||
# 4. Define output layers
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
# TODO: should use out_channels for continuous projections
|
||||
self.proj_out = nn.Linear(inner_dim, in_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
image_only_indicator: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size * num frames, channel, height, width)`):
|
||||
Input hidden_states.
|
||||
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
|
||||
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
|
||||
images, 0 indicates that the input contains video frames.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
||||
returned, otherwise a `tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
batch_frames, _, height, width = hidden_states.shape
|
||||
num_frames = image_only_indicator.shape[-1]
|
||||
batch_size = batch_frames // num_frames
|
||||
|
||||
time_context = encoder_hidden_states
|
||||
time_context_first_timestep = time_context[None, :].reshape(
|
||||
batch_size, num_frames, -1, time_context.shape[-1]
|
||||
)[:, 0]
|
||||
time_context = time_context_first_timestep[None, :].broadcast_to(
|
||||
height * width, batch_size, 1, time_context.shape[-1]
|
||||
)
|
||||
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
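# the temporal blocks cross-attend to the encoder states of the first frame only, broadcast to every spatial location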
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
|
||||
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
|
||||
num_frames_emb = num_frames_emb.reshape(-1)
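# embed the frame index (0..num_frames-1) for every sample; this positional signal is added to the temporal branch below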
|
||||
t_emb = self.time_proj(num_frames_emb)
|
||||
|
||||
# `Timesteps` does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
||||
|
||||
emb = self.time_pos_embed(t_emb)
|
||||
emb = emb[:, None, :]
|
||||
|
||||
# 2. Blocks
|
||||
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
block,
|
||||
hidden_states,
|
||||
None,
|
||||
encoder_hidden_states,
|
||||
None,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
)
|
||||
|
||||
hidden_states_mix = hidden_states
|
||||
hidden_states_mix = hidden_states_mix + emb
|
||||
|
||||
hidden_states_mix = temporal_block(
|
||||
hidden_states_mix,
|
||||
num_frames=num_frames,
|
||||
encoder_hidden_states=time_context,
|
||||
)
|
||||
hidden_states = self.time_mixer(
|
||||
x_spatial=hidden_states,
|
||||
x_temporal=hidden_states_mix,
|
||||
image_only_indicator=image_only_indicator,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
output = hidden_states + residual
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return TransformerTemporalModelOutput(sample=output)
|
||||
@@ -22,6 +22,7 @@ from ...utils import is_torch_version, logging
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..activations import get_activation
|
||||
from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
|
||||
from ..dual_transformer_2d import DualTransformer2DModel
|
||||
from ..normalization import AdaGroupNorm
|
||||
from ..resnet import (
|
||||
Downsample2D,
|
||||
@@ -33,8 +34,7 @@ from ..resnet import (
|
||||
ResnetBlockCondNorm2D,
|
||||
Upsample2D,
|
||||
)
|
||||
from ..transformers.dual_transformer_2d import DualTransformer2DModel
|
||||
from ..transformers.transformer_2d import Transformer2DModel
|
||||
from ..transformer_2d import Transformer2DModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -20,6 +20,7 @@ from torch import nn
|
||||
from ...utils import is_torch_version
|
||||
from ...utils.torch_utils import apply_freeu
|
||||
from ..attention import Attention
|
||||
from ..dual_transformer_2d import DualTransformer2DModel
|
||||
from ..resnet import (
|
||||
Downsample2D,
|
||||
ResnetBlock2D,
|
||||
@@ -27,9 +28,8 @@ from ..resnet import (
|
||||
TemporalConvLayer,
|
||||
Upsample2D,
|
||||
)
|
||||
from ..transformers.dual_transformer_2d import DualTransformer2DModel
|
||||
from ..transformers.transformer_2d import Transformer2DModel
|
||||
from ..transformers.transformer_temporal import (
|
||||
from ..transformer_2d import Transformer2DModel
|
||||
from ..transformer_temporal import (
|
||||
TransformerSpatioTemporalModel,
|
||||
TransformerTemporalModel,
|
||||
)
|
||||
|
||||
@@ -33,7 +33,7 @@ from ..attention_processor import (
|
||||
)
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.transformer_temporal import TransformerTemporalModel
|
||||
from ..transformer_temporal import TransformerTemporalModel
|
||||
from .unet_3d_blocks import (
|
||||
CrossAttnDownBlock3D,
|
||||
CrossAttnUpBlock3D,
|
||||
|
||||
@@ -29,7 +29,7 @@ from ..attention_processor import (
|
||||
)
|
||||
from ..embeddings import TimestepEmbedding, Timesteps
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..transformers.transformer_temporal import TransformerTemporalModel
|
||||
from ..transformer_temporal import TransformerTemporalModel
|
||||
from .unet_2d_blocks import UNetMidBlock2DCrossAttn
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from .unet_3d_blocks import (
|
||||
|
||||
@@ -35,7 +35,7 @@ from ...models.embeddings import (
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
|
||||
from ...models.transformers.transformer_2d import Transformer2DModel
|
||||
from ...models.transformer_2d import Transformer2DModel
|
||||
from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
|
||||
from ...models.unets.unet_2d_condition import UNet2DConditionOutput
|
||||
from ...utils import BaseOutput, is_torch_version, logging
|
||||
|
||||
@@ -19,6 +19,7 @@ from ....models.attention_processor import (
|
||||
AttnAddedKVProcessor2_0,
|
||||
AttnProcessor,
|
||||
)
|
||||
from ....models.dual_transformer_2d import DualTransformer2DModel
|
||||
from ....models.embeddings import (
|
||||
GaussianFourierProjection,
|
||||
ImageHintTimeEmbedding,
|
||||
@@ -31,8 +32,7 @@ from ....models.embeddings import (
|
||||
Timesteps,
|
||||
)
|
||||
from ....models.resnet import ResnetBlockCondNorm2D
|
||||
from ....models.transformers.dual_transformer_2d import DualTransformer2DModel
|
||||
from ....models.transformers.transformer_2d import Transformer2DModel
|
||||
from ....models.transformer_2d import Transformer2DModel
|
||||
from ....models.unets.unet_2d_condition import UNet2DConditionOutput
|
||||
from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ....utils.torch_utils import apply_freeu
|
||||
|
||||
@@ -172,7 +172,7 @@ class OnnxRuntimeModel:
|
||||
# load model from local directory
|
||||
if os.path.isdir(model_id):
|
||||
model = OnnxRuntimeModel.load_model(
|
||||
Path(model_id, model_file_name).as_posix(), provider=provider, sess_options=sess_options
|
||||
os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
|
||||
)
|
||||
kwargs["model_save_dir"] = Path(model_id)
|
||||
# load model from hub
|
||||
|
||||
@@ -60,7 +60,6 @@ from ..utils import (
|
||||
logging,
|
||||
numpy_to_pil,
|
||||
)
|
||||
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from ..utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
@@ -726,11 +725,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
self.save_config(save_directory)
|
||||
|
||||
if push_to_hub:
|
||||
# Create a new empty model card and eventually tag it
|
||||
model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
|
||||
model_card = populate_model_card(model_card)
|
||||
model_card.save(os.path.join(save_directory, "README.md"))
|
||||
|
||||
self._upload_folder(
|
||||
save_directory,
|
||||
repo_id,
|
||||
|
||||
@@ -10,7 +10,7 @@ from ...models.attention import FeedForward
|
||||
from ...models.attention_processor import Attention
|
||||
from ...models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
|
||||
from ...models.normalization import AdaLayerNorm
|
||||
from ...models.transformers.transformer_2d import Transformer2DModelOutput
|
||||
from ...models.transformer_2d import Transformer2DModelOutput
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
|
||||
@@ -125,7 +125,10 @@ def export_to_video(
    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name

    if isinstance(video_frames[0], PIL.Image.Image):
    if isinstance(video_frames[0], np.ndarray):
        video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]

    elif isinstance(video_frames[0], PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")

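This hunk teaches `export_to_video` to accept either NumPy frames (floats in [0, 1]) or PIL images before handing them to OpenCV. A self-contained hedged sketch of that branching; the function name, parameters, and `fps` default here are assumptions, not the exact diffusers signature:

```python
import tempfile

import cv2
import numpy as np
import PIL.Image


def export_frames_to_mp4(video_frames, output_video_path=None, fps=8):
    """Hedged sketch mirroring the branching in the hunk above, not the
    exact diffusers implementation."""
    if output_video_path is None:
        output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4").name

    if isinstance(video_frames[0], np.ndarray):
        # float frames in [0, 1] -> uint8 frames in [0, 255]
        video_frames = [(frame * 255).astype(np.uint8) for frame in video_frames]
    elif isinstance(video_frames[0], PIL.Image.Image):
        video_frames = [np.array(frame) for frame in video_frames]

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    h, w, _ = video_frames[0].shape
    writer = cv2.VideoWriter(output_video_path, fourcc, fps, (w, h))
    for frame in video_frames:
        # OpenCV expects BGR channel order
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    writer.release()
    return output_video_path


# Illustrative call with random frames
frames = [np.random.rand(64, 64, 3).astype(np.float32) for _ in range(8)]
print(export_frames_to_mp4(frames))
```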
@@ -28,6 +28,7 @@ from huggingface_hub import (
    ModelCard,
    ModelCardData,
    create_repo,
    get_full_repo_name,
    hf_hub_download,
    upload_folder,
)
@@ -66,6 +67,7 @@ from .logging import get_logger
logger = get_logger(__name__)


MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "model_card_template.md"
SESSION_ID = uuid4().hex

@@ -93,20 +95,7 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    return ua


def load_or_create_model_card(
    repo_id_or_path: Optional[str] = None, token: Optional[str] = None, is_pipeline: bool = False
) -> ModelCard:
    """
    Loads or creates a model card.

    Args:
        repo_id (`str`):
            The repo_id where to look for the model card.
        token (`str`, *optional*):
            Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details.
        is_pipeline (`bool`, *optional*):
            Boolean to indicate if we're adding a tag to a [`DiffusionPipeline`].
    """
def create_model_card(args, model_name):
    if not is_jinja_available():
        raise ValueError(
            "Modelcard rendering is based on Jinja templates."
@@ -114,24 +103,45 @@ def load_or_create_model_card(
            " To install it, please run `pip install Jinja2`."
        )

    try:
        # Check if the model card is present on the remote repo
        model_card = ModelCard.load(repo_id_or_path, token=token)
    except EntryNotFoundError:
        # Otherwise create a simple model card from template
        component = "pipeline" if is_pipeline else "model"
        model_description = f"This is the model card of a 🧨 diffusers {component} that has been pushed on the Hub. This model card has been automatically generated."
        card_data = ModelCardData()
        model_card = ModelCard.from_template(card_data, model_description=model_description)
    if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
        return

    return model_card
    hub_token = args.hub_token if hasattr(args, "hub_token") else None
    repo_name = get_full_repo_name(model_name, token=hub_token)

    model_card = ModelCard.from_template(
        card_data=ModelCardData(  # Card metadata object that will be converted to YAML block
            language="en",
            license="apache-2.0",
            library_name="diffusers",
            tags=[],
            datasets=args.dataset_name,
            metrics=[],
        ),
        template_path=MODEL_CARD_TEMPLATE_PATH,
        model_name=model_name,
        repo_name=repo_name,
        dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
        learning_rate=args.learning_rate,
        train_batch_size=args.train_batch_size,
        eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=(
            args.gradient_accumulation_steps if hasattr(args, "gradient_accumulation_steps") else None
        ),
        adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
        adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
        adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
        adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
        lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
        lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
        ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
        ema_power=args.ema_power if hasattr(args, "ema_power") else None,
        ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
        mixed_precision=args.mixed_precision,
    )

def populate_model_card(model_card: ModelCard) -> ModelCard:
    """Populates the `model_card` with library name."""
    if model_card.data.library_name is None:
        model_card.data.library_name = "diffusers"
    return model_card
    card_path = os.path.join(args.output_dir, "README.md")
    model_card.save(card_path)


def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] = None):

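One side of this hunk is the helper pair `load_or_create_model_card` / `populate_model_card` that the pipeline and model `push_to_hub` paths call. A minimal hedged sketch of how they compose, using the signature defined in the hunk above; the repo id is a placeholder:

```python
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card

# Hedged sketch: "your-username/my-pipeline" is a placeholder repo. Per the code
# above, an existing README.md is loaded as a ModelCard; if the repo has no
# card yet, a simple templated one is created instead.
model_card = load_or_create_model_card("your-username/my-pipeline", is_pipeline=True)

# Tags the card with library_name="diffusers" unless it is already set.
model_card = populate_model_card(model_card)
model_card.save("README.md")
```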
@@ -425,10 +435,6 @@ class PushToHubMixin:
        """
        repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id

        # Create a new empty model card and eventually tag it
        model_card = load_or_create_model_card(repo_id, token=token)
        model_card = populate_model_card(model_card)

        # Save all files.
        save_kwargs = {"safe_serialization": safe_serialization}
        if "Scheduler" not in self.__class__.__name__:
@@ -437,9 +443,6 @@ class PushToHubMixin:
        with tempfile.TemporaryDirectory() as tmpdir:
            self.save_pretrained(tmpdir, **save_kwargs)

            # Update model card if needed:
            model_card.save(os.path.join(tmpdir, "README.md"))

            return self._upload_folder(
                tmpdir,
                repo_id,

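The mixin method above is what `push_to_hub` runs on any saveable diffusers object: create the repo, write the weights and a model card into a temporary folder, and upload it. A minimal hedged usage example; the repo name is a placeholder and a valid Hub token (`huggingface-cli login`) is required:

```python
from diffusers import UNet2DConditionModel

# Hedged sketch: a tiny randomly initialized UNet (same config as the tests in
# this diff) keeps the upload small; the repo name is a placeholder.
model = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

# Creates the repo, writes the weights plus a README.md model card, and
# uploads the temporary folder, as in the mixin hunk above.
model.push_to_hub("your-username/tiny-unet-test")
```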
50  src/diffusers/utils/model_card_template.md  Normal file
@@ -0,0 +1,50 @@
---
{{ card_data }}
---

<!-- This model card has been generated automatically according to the information the training script had access to. You
should probably proofread and complete it, then remove this comment. -->

# {{ model_name | default("Diffusion Model") }}

## Model description

This diffusion model is trained with the [🤗 Diffusers](https://github.com/huggingface/diffusers) library
on the `{{ dataset_name }}` dataset.

## Intended uses & limitations

#### How to use

```python
# TODO: add an example code snippet for running this diffusion pipeline
```

#### Limitations and bias

[TODO: provide examples of latent issues and potential remediations]

## Training data

[TODO: describe the data used to train the model]

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: {{ learning_rate }}
- train_batch_size: {{ train_batch_size }}
- eval_batch_size: {{ eval_batch_size }}
- gradient_accumulation_steps: {{ gradient_accumulation_steps }}
- optimizer: AdamW with betas=({{ adam_beta1 }}, {{ adam_beta2 }}), weight_decay={{ adam_weight_decay }} and epsilon={{ adam_epsilon }}
- lr_scheduler: {{ lr_scheduler }}
- lr_warmup_steps: {{ lr_warmup_steps }}
- ema_inv_gamma: {{ ema_inv_gamma }}
- ema_power: {{ ema_power }}
- ema_max_decay: {{ ema_max_decay }}
- mixed_precision: {{ mixed_precision }}

### Training results

📈 [TensorBoard logs](https://huggingface.co/{{ repo_name }}/tensorboard?#scalars)

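This Jinja template is what the restored `create_model_card` feeds to `ModelCard.from_template`. A small hedged sketch of rendering it directly; the variable values are dummies and the template path assumes a checkout of this repository:

```python
from huggingface_hub import ModelCard, ModelCardData

# Hedged sketch: render the template above with dummy hyperparameters.
card = ModelCard.from_template(
    card_data=ModelCardData(language="en", license="apache-2.0", library_name="diffusers"),
    template_path="src/diffusers/utils/model_card_template.md",  # assumed repo-relative path
    model_name="my-ddpm-butterflies",
    repo_name="your-username/my-ddpm-butterflies",
    dataset_name="huggan/smithsonian_butterflies_subset",
    learning_rate=1e-4,
    train_batch_size=16,
    eval_batch_size=16,
    gradient_accumulation_steps=1,
    adam_beta1=0.95,
    adam_beta2=0.999,
    adam_weight_decay=1e-6,
    adam_epsilon=1e-8,
    lr_scheduler="cosine",
    lr_warmup_steps=500,
    ema_inv_gamma=1.0,
    ema_power=0.75,
    ema_max_decay=0.9999,
    mixed_precision="fp16",
)
print(card.content[:200])  # rendered Markdown with the YAML block on top
card.save("README.md")
```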
@@ -24,7 +24,7 @@ from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformers.transformer_2d import Transformer2DModel
from diffusers.models.transformer_2d import Transformer2DModel
from diffusers.utils.testing_utils import (
    backend_manual_seed,
    require_torch_accelerator_with_fp64,

@@ -24,8 +24,7 @@ from typing import Dict, List, Tuple
import numpy as np
import requests_mock
import torch
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub.utils import is_jinja_available
from huggingface_hub import delete_repo
from requests.exceptions import HTTPError

from diffusers.models import UNet2DConditionModel
@@ -733,26 +732,3 @@ class ModelPushToHubTester(unittest.TestCase):

        # Reset repo
        delete_repo(self.org_repo_id, token=TOKEN)

    @unittest.skipIf(
        not is_jinja_available(),
        reason="Model card tests cannot be performed without Jinja installed.",
    )
    def test_push_to_hub_library_name(self):
        model = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        model.push_to_hub(self.repo_id, token=TOKEN)

        model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
        assert model_card.library_name == "diffusers"

        # Reset repo
        delete_repo(self.repo_id, token=TOKEN)

@@ -30,7 +30,7 @@ from diffusers.utils.testing_utils import (
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin
from .test_modeling_common import ModelTesterMixin


enable_full_determinism()

@@ -15,15 +15,37 @@
import unittest
from pathlib import Path
from tempfile import TemporaryDirectory
from unittest.mock import Mock, patch

from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
import diffusers.utils.hub_utils


class CreateModelCardTest(unittest.TestCase):
    def test_generate_model_card_with_library_name(self):
    @patch("diffusers.utils.hub_utils.get_full_repo_name")
    def test_create_model_card(self, repo_name_mock: Mock) -> None:
        repo_name_mock.return_value = "full_repo_name"
        with TemporaryDirectory() as tmpdir:
            file_path = Path(tmpdir) / "README.md"
            file_path.write_text("---\nlibrary_name: foo\n---\nContent\n")
            model_card = load_or_create_model_card(file_path)
            populate_model_card(model_card)
            assert model_card.data.library_name == "foo"
            # Dummy args values
            args = Mock()
            args.output_dir = tmpdir
            args.local_rank = 0
            args.hub_token = "hub_token"
            args.dataset_name = "dataset_name"
            args.learning_rate = 0.01
            args.train_batch_size = 100000
            args.eval_batch_size = 10000
            args.gradient_accumulation_steps = 0.01
            args.adam_beta1 = 0.02
            args.adam_beta2 = 0.03
            args.adam_weight_decay = 0.0005
            args.adam_epsilon = 0.000001
            args.lr_scheduler = 1
            args.lr_warmup_steps = 10
            args.ema_inv_gamma = 0.001
            args.ema_power = 0.1
            args.ema_max_decay = 0.2
            args.mixed_precision = True

            # Model card must be rendered and saved
            diffusers.utils.hub_utils.create_model_card(args, model_name="model_name")
            self.assertTrue((Path(tmpdir) / "README.md").is_file())

@@ -117,9 +117,9 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array([0.80810547, 0.88183594, 0.9296875, 0.9189453, 0.9848633, 1.0, 0.97021484, 1.0, 1.0])
|
||||
expected_slice = np.array([0.8110, 0.8843, 0.9326, 0.9224, 0.9878, 1.0, 0.9736, 1.0, 1.0])
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
|
||||
|
||||
@@ -127,11 +127,9 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array(
|
||||
[0.30444336, 0.26513672, 0.22436523, 0.2758789, 0.25585938, 0.20751953, 0.25390625, 0.24633789, 0.21923828]
|
||||
)
|
||||
expected_slice = np.array([0.3013, 0.2615, 0.2202, 0.2722, 0.2510, 0.2023, 0.2498, 0.2415, 0.2139])
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_image_to_image(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
@@ -145,11 +143,9 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array(
|
||||
[0.22167969, 0.21875, 0.21728516, 0.22607422, 0.21948242, 0.23925781, 0.22387695, 0.25268555, 0.2722168]
|
||||
)
|
||||
expected_slice = np.array([0.2253, 0.2251, 0.2219, 0.2312, 0.2236, 0.2434, 0.2275, 0.2575, 0.2805])
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
|
||||
|
||||
@@ -157,11 +153,9 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array(
|
||||
[0.35913086, 0.265625, 0.26367188, 0.24658203, 0.19750977, 0.39990234, 0.15258789, 0.20336914, 0.5517578]
|
||||
)
|
||||
expected_slice = np.array([0.3550, 0.2600, 0.2520, 0.2412, 0.1870, 0.3831, 0.1453, 0.1880, 0.5371])
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_inpainting(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
@@ -175,11 +169,9 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array(
|
||||
[0.27148438, 0.24047852, 0.22167969, 0.23217773, 0.21118164, 0.21142578, 0.21875, 0.20751953, 0.20019531]
|
||||
)
|
||||
expected_slice = np.array([0.2700, 0.2388, 0.2202, 0.2304, 0.2095, 0.2097, 0.2173, 0.2058, 0.1987])
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
|
||||
|
||||
@@ -187,11 +179,9 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array(
|
||||
[0.27294922, 0.24023438, 0.21948242, 0.23242188, 0.20825195, 0.2055664, 0.21679688, 0.20336914, 0.19360352]
|
||||
)
|
||||
expected_slice = np.array([0.2744, 0.2410, 0.2202, 0.2334, 0.2090, 0.2053, 0.2175, 0.2033, 0.1934])
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_text_to_image_model_cpu_offload(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
@@ -234,10 +224,10 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array(
|
||||
[0.18115234, 0.13500977, 0.13427734, 0.24194336, 0.17138672, 0.16625977, 0.4260254, 0.43359375, 0.4416504]
|
||||
[0.1706543, 0.1303711, 0.12573242, 0.21777344, 0.14550781, 0.14038086, 0.40820312, 0.41455078, 0.42529297]
|
||||
)
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_unload(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
@@ -279,21 +269,9 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.09630299,
|
||||
0.09551358,
|
||||
0.08480701,
|
||||
0.09070173,
|
||||
0.09437338,
|
||||
0.09264627,
|
||||
0.08883232,
|
||||
0.09287417,
|
||||
0.09197289,
|
||||
]
|
||||
)
|
||||
expected_slice = np.array([0.0965, 0.0956, 0.0849, 0.0908, 0.0944, 0.0927, 0.0888, 0.0929, 0.0920])
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
|
||||
@@ -314,11 +292,9 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array(
|
||||
[0.0576596, 0.05600825, 0.04479006, 0.05288461, 0.05461192, 0.05137569, 0.04867965, 0.05301541, 0.04939842]
|
||||
)
|
||||
expected_slice = np.array([0.0592, 0.0573, 0.0459, 0.0542, 0.0559, 0.0523, 0.0500, 0.0540, 0.0501])
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_image_to_image_sdxl(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder")
|
||||
@@ -337,21 +313,9 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.06513795,
|
||||
0.07009393,
|
||||
0.07234055,
|
||||
0.07426041,
|
||||
0.07002589,
|
||||
0.06415862,
|
||||
0.07827643,
|
||||
0.07962808,
|
||||
0.07411247,
|
||||
]
|
||||
)
|
||||
expected_slice = np.array([0.0652, 0.0698, 0.0723, 0.0744, 0.0699, 0.0636, 0.0784, 0.0803, 0.0742])
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
||||
@@ -373,21 +337,9 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
images = pipeline(**inputs).images
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
|
||||
expected_slice = np.array(
|
||||
[
|
||||
0.07126552,
|
||||
0.07025367,
|
||||
0.07348302,
|
||||
0.07580167,
|
||||
0.07467338,
|
||||
0.06918576,
|
||||
0.07480252,
|
||||
0.08279955,
|
||||
0.08547315,
|
||||
]
|
||||
)
|
||||
expected_slice = np.array([0.0708, 0.0701, 0.0735, 0.0760, 0.0739, 0.0679, 0.0756, 0.0824, 0.0837])
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_inpainting_sdxl(self):
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="sdxl_models/image_encoder")
|
||||
@@ -407,11 +359,9 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
|
||||
image_slice = images[0, :3, :3, -1].flatten()
|
||||
image_slice.tolist()
|
||||
|
||||
expected_slice = np.array(
|
||||
[0.14181179, 0.1493012, 0.14283323, 0.14602411, 0.14915377, 0.15015268, 0.14725655, 0.15009224, 0.15164584]
|
||||
)
|
||||
expected_slice = np.array([0.1420, 0.1495, 0.1430, 0.1462, 0.1493, 0.1502, 0.1474, 0.1502, 0.1517])
|
||||
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-3)
|
||||
assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
|
||||
feature_extractor = self.get_image_processor("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
|
||||
|
||||
283  tests/pipelines/stable_diffusion/test_cycle_diffusion.py  Normal file
@@ -0,0 +1,283 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import AutoencoderKL, CycleDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
require_torch_gpu,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..pipeline_params import (
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class CycleDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = CycleDiffusionPipeline
|
||||
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {
|
||||
"negative_prompt",
|
||||
"height",
|
||||
"width",
|
||||
"negative_prompt_embeds",
|
||||
}
|
||||
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
|
||||
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"source_prompt"})
|
||||
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
unet = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear",
|
||||
num_train_timesteps=1000,
|
||||
clip_sample=False,
|
||||
set_alpha_to_one=False,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
text_encoder = CLIPTextModel(text_encoder_config)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
components = {
|
||||
"unet": unet,
|
||||
"scheduler": scheduler,
|
||||
"vae": vae,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"safety_checker": None,
|
||||
"feature_extractor": None,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
image = image / 2 + 0.5
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "An astronaut riding an elephant",
|
||||
"source_prompt": "An astronaut riding a horse",
|
||||
"image": image,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"eta": 0.1,
|
||||
"strength": 0.8,
|
||||
"guidance_scale": 3,
|
||||
"source_guidance_scale": 1,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_cycle(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = CycleDiffusionPipeline(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
output = pipe(**inputs)
|
||||
images = output.images
|
||||
|
||||
image_slice = images[0, -3:, -3:, -1]
|
||||
|
||||
assert images.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.4459, 0.4943, 0.4544, 0.6643, 0.5474, 0.4327, 0.5701, 0.5959, 0.5179])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@unittest.skipIf(torch_device != "cuda", "This test requires a GPU")
|
||||
def test_stable_diffusion_cycle_fp16(self):
|
||||
components = self.get_dummy_components()
|
||||
for name, module in components.items():
|
||||
if hasattr(module, "half"):
|
||||
components[name] = module.half()
|
||||
pipe = CycleDiffusionPipeline(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output = pipe(**inputs)
|
||||
images = output.images
|
||||
|
||||
image_slice = images[0, -3:, -3:, -1]
|
||||
|
||||
assert images.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.3506, 0.4543, 0.446, 0.4575, 0.5195, 0.4155, 0.5273, 0.518, 0.4116])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
@skip_mps
|
||||
def test_save_load_local(self):
|
||||
return super().test_save_load_local()
|
||||
|
||||
@unittest.skip("non-deterministic pipeline")
|
||||
def test_inference_batch_single_identical(self):
|
||||
return super().test_inference_batch_single_identical()
|
||||
|
||||
@skip_mps
|
||||
def test_dict_tuple_outputs_equivalent(self):
|
||||
return super().test_dict_tuple_outputs_equivalent()
|
||||
|
||||
@skip_mps
|
||||
def test_save_load_optional_components(self):
|
||||
return super().test_save_load_optional_components()
|
||||
|
||||
@skip_mps
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
return super().test_attention_slicing_forward_pass()
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
class CycleDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_cycle_diffusion_pipeline_fp16(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/cycle-diffusion/black_colored_car.png"
|
||||
)
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/cycle-diffusion/blue_colored_car_fp16.npy"
|
||||
)
|
||||
init_image = init_image.resize((512, 512))
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
pipe = CycleDiffusionPipeline.from_pretrained(
|
||||
model_id, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16, revision="fp16"
|
||||
)
|
||||
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
source_prompt = "A black colored car"
|
||||
prompt = "A blue colored car"
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
source_prompt=source_prompt,
|
||||
image=init_image,
|
||||
num_inference_steps=100,
|
||||
eta=0.1,
|
||||
strength=0.85,
|
||||
guidance_scale=3,
|
||||
source_guidance_scale=1,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
)
|
||||
image = output.images
|
||||
|
||||
# the values aren't exactly equal, but the images look the same visually
|
||||
assert np.abs(image - expected_image).max() < 5e-1
|
||||
|
||||
def test_cycle_diffusion_pipeline(self):
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/cycle-diffusion/black_colored_car.png"
|
||||
)
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/cycle-diffusion/blue_colored_car.npy"
|
||||
)
|
||||
init_image = init_image.resize((512, 512))
|
||||
|
||||
model_id = "CompVis/stable-diffusion-v1-4"
|
||||
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
|
||||
pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, safety_checker=None)
|
||||
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
source_prompt = "A black colored car"
|
||||
prompt = "A blue colored car"
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
source_prompt=source_prompt,
|
||||
image=init_image,
|
||||
num_inference_steps=100,
|
||||
eta=0.1,
|
||||
strength=0.85,
|
||||
guidance_scale=3,
|
||||
source_guidance_scale=1,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
)
|
||||
image = output.images
|
||||
|
||||
assert np.abs(image - expected_image).max() < 2e-2
|
||||
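The integration test above doubles as a usage recipe for `CycleDiffusionPipeline`. A condensed hedged sketch of the same call, using the checkpoint, scheduler, and image the test uses (the pipeline expects a DDIM scheduler); output handling here is an illustrative choice:

```python
import torch
from diffusers import CycleDiffusionPipeline, DDIMScheduler
from diffusers.utils import load_image

# Hedged sketch condensed from the test above; adjust device and steps to taste.
model_id = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, safety_checker=None)
pipe.to("cuda")

init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
    "/cycle-diffusion/black_colored_car.png"
).resize((512, 512))

image = pipe(
    prompt="A blue colored car",
    source_prompt="A black colored car",
    image=init_image,
    num_inference_steps=100,
    eta=0.1,
    strength=0.85,
    guidance_scale=3,
    source_guidance_scale=1,
    generator=torch.manual_seed(0),
).images[0]
image.save("blue_car.png")
```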
@@ -0,0 +1,630 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import random
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
VQModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_image,
|
||||
load_numpy,
|
||||
nightly,
|
||||
preprocess_image,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@property
|
||||
def dummy_image(self):
|
||||
batch_size = 1
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
|
||||
return image
|
||||
|
||||
@property
|
||||
def dummy_uncond_unet(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
|
||||
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_cond_unet(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_cond_unet_inpaint(self):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=32,
|
||||
in_channels=9,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_vq_model(self):
|
||||
torch.manual_seed(0)
|
||||
model = VQModel(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=3,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_vae(self):
|
||||
torch.manual_seed(0)
|
||||
model = AutoencoderKL(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=4,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
return CLIPTextModel(config)
|
||||
|
||||
@property
|
||||
def dummy_extractor(self):
|
||||
def extract(*args, **kwargs):
|
||||
class Out:
|
||||
def __init__(self):
|
||||
self.pixel_values = torch.ones([0])
|
||||
|
||||
def to(self, device):
|
||||
self.pixel_values.to(device)
|
||||
return self
|
||||
|
||||
return Out()
|
||||
|
||||
return extract
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
)
|
||||
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image_from_tuple = sd_pipe(
|
||||
[prompt],
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.4941, 0.5396, 0.4689, 0.6338, 0.5392, 0.4094, 0.5477, 0.5904, 0.5165])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy_batched(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
|
||||
init_images_tens = preprocess_image(init_image, batch_size=2)
|
||||
init_masks_tens = init_images_tens + 4
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
images = sd_pipe(
|
||||
[prompt] * 2,
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_images_tens,
|
||||
mask_image=init_masks_tens,
|
||||
).images
|
||||
|
||||
assert images.shape == (2, 32, 32, 3)
|
||||
|
||||
image_slice_0 = images[0, -3:, -3:, -1].flatten()
|
||||
image_slice_1 = images[1, -3:, -3:, -1].flatten()
|
||||
|
||||
expected_slice_0 = np.array([0.4697, 0.3770, 0.4096, 0.4653, 0.4497, 0.4183, 0.3950, 0.4668, 0.4672])
|
||||
expected_slice_1 = np.array([0.4105, 0.4987, 0.5771, 0.4921, 0.4237, 0.5684, 0.5496, 0.4645, 0.5272])
|
||||
|
||||
assert np.abs(expected_slice_0 - image_slice_0).max() < 1e-2
|
||||
assert np.abs(expected_slice_1 - image_slice_1).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy_negative_prompt(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
negative_prompt = "french fries"
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = sd_pipe(
|
||||
prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
)
|
||||
|
||||
image = output.images
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 32, 32, 3)
|
||||
expected_slice = np.array([0.4941, 0.5396, 0.4689, 0.6338, 0.5392, 0.4094, 0.5477, 0.5904, 0.5165])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy_num_images_per_prompt(self):
|
||||
device = "cpu"
|
||||
unet = self.dummy_cond_unet
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
|
||||
mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((32, 32))
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy(
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
|
||||
# test num_images_per_prompt=1 (default)
|
||||
images = sd_pipe(
|
||||
prompt,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
).images
|
||||
|
||||
assert images.shape == (1, 32, 32, 3)
|
||||
|
||||
# test num_images_per_prompt=1 (default) for batch of prompts
|
||||
batch_size = 2
|
||||
images = sd_pipe(
|
||||
[prompt] * batch_size,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
).images
|
||||
|
||||
assert images.shape == (batch_size, 32, 32, 3)
|
||||
|
||||
# test num_images_per_prompt for single prompt
|
||||
num_images_per_prompt = 2
|
||||
images = sd_pipe(
|
||||
prompt,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
).images
|
||||
|
||||
assert images.shape == (num_images_per_prompt, 32, 32, 3)
|
||||
|
||||
# test num_images_per_prompt for batch of prompts
|
||||
batch_size = 2
|
||||
images = sd_pipe(
|
||||
[prompt] * batch_size,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
).images
|
||||
|
||||
assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class StableDiffusionInpaintLegacyPipelineSlowTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, generator_device="cpu", seed=0):
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_image.png"
|
||||
)
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_mask.png"
|
||||
)
|
||||
inputs = {
|
||||
"prompt": "A red cat sitting on a park bench",
|
||||
"image": init_image,
|
||||
"mask_image": mask_image,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 3,
|
||||
"strength": 0.75,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy_pndm(self):
|
||||
pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", safety_checker=None
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.5665, 0.6117, 0.6430, 0.4057, 0.4594, 0.5658, 0.1596, 0.3106, 0.4305])
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() < 3e-3
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy_batched(self):
|
||||
pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", safety_checker=None
|
||||
)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
inputs["prompt"] = [inputs["prompt"]] * 2
|
||||
inputs["image"] = preprocess_image(inputs["image"], batch_size=2)
|
||||
|
||||
mask = inputs["mask_image"].convert("L")
|
||||
mask = np.array(mask).astype(np.float32) / 255.0
|
||||
mask = torch.from_numpy(1 - mask)
|
||||
masks = torch.vstack([mask[None][None]] * 2)
|
||||
inputs["mask_image"] = masks
|
||||
|
||||
image = pipe(**inputs).images
|
||||
assert image.shape == (2, 512, 512, 3)
|
||||
|
||||
image_slice_0 = image[0, 253:256, 253:256, -1].flatten()
|
||||
image_slice_1 = image[1, 253:256, 253:256, -1].flatten()
|
||||
|
||||
expected_slice_0 = np.array(
|
||||
[0.52093095, 0.4176447, 0.32752383, 0.6175223, 0.50563973, 0.36470804, 0.65460044, 0.5775188, 0.44332123]
|
||||
)
|
||||
expected_slice_1 = np.array(
|
||||
[0.3592432, 0.4233033, 0.3914635, 0.31014425, 0.3702293, 0.39412856, 0.17526966, 0.2642669, 0.37480092]
|
||||
)
|
||||
|
||||
assert np.abs(expected_slice_0 - image_slice_0).max() < 3e-3
|
||||
assert np.abs(expected_slice_1 - image_slice_1).max() < 3e-3
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy_k_lms(self):
|
||||
pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", safety_checker=None
|
||||
)
|
||||
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
image = pipe(**inputs).images
|
||||
image_slice = image[0, 253:256, 253:256, -1].flatten()
|
||||
|
||||
assert image.shape == (1, 512, 512, 3)
|
||||
expected_slice = np.array([0.4534, 0.4467, 0.4329, 0.4329, 0.4339, 0.4220, 0.4244, 0.4332, 0.4426])
|
||||
|
||||
assert np.abs(expected_slice - image_slice).max() < 3e-3
|
||||
|
||||
def test_stable_diffusion_inpaint_legacy_intermediate_state(self):
|
||||
number_of_steps = 0
|
||||
|
||||
def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
|
||||
callback_fn.has_been_called = True
|
||||
nonlocal number_of_steps
|
||||
number_of_steps += 1
|
||||
if step == 1:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([0.5977, 1.5449, 1.0586, -0.3250, 0.7383, -0.0862, 0.4631, -0.2571, -1.1289])
|
||||
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||
elif step == 2:
|
||||
latents = latents.detach().cpu().numpy()
|
||||
assert latents.shape == (1, 4, 64, 64)
|
||||
latents_slice = latents[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([0.5190, 1.1621, 0.6885, 0.2424, 0.3337, -0.1617, 0.6914, -0.1957, -0.5474])
|
||||
|
||||
assert np.abs(latents_slice.flatten() - expected_slice).max() < 1e-3
|
||||
|
||||
callback_fn.has_been_called = False
|
||||
|
||||
pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
pipe.enable_attention_slicing()
|
||||
|
||||
inputs = self.get_inputs()
|
||||
pipe(**inputs, callback=callback_fn, callback_steps=1)
|
||||
assert callback_fn.has_been_called
|
||||
assert number_of_steps == 2
|
||||
|
||||
|
||||
@nightly
|
||||
@require_torch_gpu
|
||||
class StableDiffusionInpaintLegacyPipelineNightlyTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_image.png"
|
||||
)
|
||||
mask_image = load_image(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint/input_bench_mask.png"
|
||||
)
|
||||
inputs = {
|
||||
"prompt": "A red cat sitting on a park bench",
|
||||
"image": init_image,
|
||||
"mask_image": mask_image,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 50,
|
||||
"strength": 0.75,
|
||||
"guidance_scale": 7.5,
|
||||
"output_type": "numpy",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inpaint_pndm(self):
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = sd_pipe(**inputs).images[0]
|
||||
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_pndm.npy"
|
||||
)
|
||||
max_diff = np.abs(expected_image - image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_inpaint_ddim(self):
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = sd_pipe(**inputs).images[0]
|
||||
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_ddim.npy"
|
||||
)
|
||||
max_diff = np.abs(expected_image - image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_inpaint_lms(self):
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
image = sd_pipe(**inputs).images[0]
|
||||
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_lms.npy"
|
||||
)
|
||||
max_diff = np.abs(expected_image - image).max()
|
||||
assert max_diff < 1e-3
|
||||
|
||||
def test_inpaint_dpm(self):
|
||||
sd_pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)
|
||||
sd_pipe.to(torch_device)
|
||||
sd_pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_inputs(torch_device)
|
||||
inputs["num_inference_steps"] = 30
|
||||
image = sd_pipe(**inputs).images[0]
|
||||
|
||||
expected_image = load_numpy(
|
||||
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
|
||||
"/stable_diffusion_inpaint_legacy/stable_diffusion_1_5_dpm_multi.npy"
|
||||
)
|
||||
max_diff = np.abs(expected_image - image).max()
|
        assert max_diff < 1e-3
@@ -0,0 +1,255 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    EulerAncestralDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionModelEditingPipeline,
    UNet2DConditionModel,
)
from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, skip_mps, torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin


enable_full_determinism()


@skip_mps
class StableDiffusionModelEditingPipelineFastTests(
    PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
):
    pipeline_class = StableDiffusionModelEditingPipeline
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = DDIMScheduler()
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        generator = torch.manual_seed(seed)
        inputs = {
            "prompt": "A field of roses",
            "generator": generator,
            # Setting height and width to None to prevent OOMs on CPU.
            "height": None,
            "width": None,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
        }
        return inputs

    def test_stable_diffusion_model_editing_default_case(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionModelEditingPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]
        assert image.shape == (1, 64, 64, 3)

        expected_slice = np.array([0.4755, 0.5132, 0.4976, 0.3904, 0.3554, 0.4765, 0.5139, 0.5158, 0.4889])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_model_editing_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionModelEditingPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        negative_prompt = "french fries"
        output = sd_pipe(**inputs, negative_prompt=negative_prompt)
        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)

        expected_slice = np.array([0.4992, 0.5101, 0.5004, 0.3949, 0.3604, 0.4735, 0.5216, 0.5204, 0.4913])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_model_editing_euler(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        components["scheduler"] = EulerAncestralDiscreteScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
        )
        sd_pipe = StableDiffusionModelEditingPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)

        expected_slice = np.array([0.4747, 0.5372, 0.4779, 0.4982, 0.5543, 0.4816, 0.5238, 0.4904, 0.5027])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_model_editing_pndm(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        components["scheduler"] = PNDMScheduler()
        sd_pipe = StableDiffusionModelEditingPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        # the pipeline does not expect pndm so test if it raises error.
        with self.assertRaises(ValueError):
            _ = sd_pipe(**inputs).images

    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical(expected_max_diff=5e-3)

    def test_attention_slicing_forward_pass(self):
        super().test_attention_slicing_forward_pass(expected_max_diff=5e-3)


@nightly
@require_torch_gpu
class StableDiffusionModelEditingSlowTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, seed=0):
        generator = torch.manual_seed(seed)
        inputs = {
            "prompt": "A field of roses",
            "generator": generator,
            "num_inference_steps": 3,
            "guidance_scale": 7.5,
            "output_type": "numpy",
        }
        return inputs

    def test_stable_diffusion_model_editing_default(self):
        model_ckpt = "CompVis/stable-diffusion-v1-4"
        pipe = StableDiffusionModelEditingPipeline.from_pretrained(model_ckpt, safety_checker=None)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs()
        image = pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1].flatten()

        assert image.shape == (1, 512, 512, 3)

        expected_slice = np.array(
            [0.6749496, 0.6386453, 0.51443267, 0.66094905, 0.61921215, 0.5491332, 0.5744417, 0.58075106, 0.5174658]
        )

        assert np.abs(expected_slice - image_slice).max() < 1e-2

        # make sure image changes after editing
        pipe.edit_model("A pack of roses", "A pack of blue roses")

        image = pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1].flatten()

        assert image.shape == (1, 512, 512, 3)

        assert np.abs(expected_slice - image_slice).max() > 1e-1

    def test_stable_diffusion_model_editing_pipeline_with_sequential_cpu_offloading(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        model_ckpt = "CompVis/stable-diffusion-v1-4"
        scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
        pipe = StableDiffusionModelEditingPipeline.from_pretrained(
            model_ckpt, scheduler=scheduler, safety_checker=None
        )
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing(1)
        pipe.enable_sequential_cpu_offload()

        inputs = self.get_inputs()
        _ = pipe(**inputs)

        mem_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 4.4 GB is allocated
        assert mem_bytes < 4.4 * 10**9
@@ -0,0 +1,228 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMParallelScheduler,
    DDPMParallelScheduler,
    StableDiffusionParadigmsPipeline,
    UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    nightly,
    require_torch_gpu,
    torch_device,
)

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


enable_full_determinism()


class StableDiffusionParadigmsPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
    pipeline_class = StableDiffusionParadigmsPipeline
    params = TEXT_TO_IMAGE_PARAMS
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
            # SD2-specific config below
            attention_head_dim=(2, 4),
            use_linear_projection=True,
        )
        scheduler = DDIMParallelScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            sample_size=128,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            # SD2-specific config below
            hidden_act="gelu",
            projection_dim=512,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "a photograph of an astronaut riding a horse",
            "generator": generator,
            "num_inference_steps": 10,
            "guidance_scale": 6.0,
            "output_type": "numpy",
            "parallel": 3,
            "debug": True,
        }
        return inputs

    def test_stable_diffusion_paradigms_default_case(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionParadigmsPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]
        assert image.shape == (1, 64, 64, 3)

        expected_slice = np.array([0.4773, 0.5417, 0.4723, 0.4925, 0.5631, 0.4752, 0.5240, 0.4935, 0.5023])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    def test_stable_diffusion_paradigms_default_case_ddpm(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        torch.manual_seed(0)
        components["scheduler"] = DDPMParallelScheduler()
        torch.manual_seed(0)
        sd_pipe = StableDiffusionParadigmsPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]
        assert image.shape == (1, 64, 64, 3)

        expected_slice = np.array([0.3573, 0.4420, 0.4960, 0.4799, 0.3796, 0.3879, 0.4819, 0.4365, 0.4468])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

    # override to speed the overall test timing up.
    def test_inference_batch_consistent(self):
        super().test_inference_batch_consistent(batch_sizes=[1, 2])

    # override to speed the overall test timing up.
    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=3e-3)

    def test_stable_diffusion_paradigms_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionParadigmsPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        negative_prompt = "french fries"
        output = sd_pipe(**inputs, negative_prompt=negative_prompt)
        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)

        expected_slice = np.array([0.4771, 0.5420, 0.4683, 0.4918, 0.5636, 0.4725, 0.5230, 0.4923, 0.5015])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2


@nightly
@require_torch_gpu
class StableDiffusionParadigmsPipelineSlowTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, seed=0):
        generator = torch.Generator(device=torch_device).manual_seed(seed)
        inputs = {
            "prompt": "a photograph of an astronaut riding a horse",
            "generator": generator,
            "num_inference_steps": 10,
            "guidance_scale": 7.5,
            "output_type": "numpy",
            "parallel": 3,
            "debug": True,
        }
        return inputs

    def test_stable_diffusion_paradigms_default(self):
        model_ckpt = "stabilityai/stable-diffusion-2-base"
        scheduler = DDIMParallelScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
        pipe = StableDiffusionParadigmsPipeline.from_pretrained(model_ckpt, scheduler=scheduler, safety_checker=None)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs()
        image = pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1].flatten()

        assert image.shape == (1, 512, 512, 3)

        expected_slice = np.array([0.9622, 0.9602, 0.9748, 0.9591, 0.9630, 0.9691, 0.9661, 0.9631, 0.9741])

        assert np.abs(expected_slice - image_slice).max() < 1e-2
@@ -0,0 +1,590 @@
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import random
import tempfile
import unittest

import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import (
    AutoencoderKL,
    DDIMInverseScheduler,
    DDIMScheduler,
    DDPMScheduler,
    EulerAncestralDiscreteScheduler,
    LMSDiscreteScheduler,
    StableDiffusionPix2PixZeroPipeline,
    UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
    load_image,
    load_numpy,
    load_pt,
    nightly,
    require_torch_gpu,
    skip_mps,
    torch_device,
)

from ..pipeline_params import (
    TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
    TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
    TEXT_TO_IMAGE_IMAGE_PARAMS,
)
from ..test_pipelines_common import (
    PipelineLatentTesterMixin,
    PipelineTesterMixin,
    assert_mean_pixel_difference,
)


enable_full_determinism()


@skip_mps
class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
    pipeline_class = StableDiffusionPix2PixZeroPipeline
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"image"}
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

    @classmethod
    def setUpClass(cls):
        cls.source_embeds = load_pt(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/src_emb_0.pt"
        )

        cls.target_embeds = load_pt(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/tgt_emb_0.pt"
        )

    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = DDIMScheduler()
        inverse_scheduler = DDIMInverseScheduler()
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
            "inverse_scheduler": inverse_scheduler,
            "caption_generator": None,
            "caption_processor": None,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        generator = torch.manual_seed(seed)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "cross_attention_guidance_amount": 0.15,
            "source_embeds": self.source_embeds,
            "target_embeds": self.target_embeds,
            "output_type": "numpy",
        }
        return inputs

    def get_dummy_inversion_inputs(self, device, seed=0):
        dummy_image = floats_tensor((2, 3, 32, 32), rng=random.Random(seed)).to(torch_device)
        dummy_image = dummy_image / 2 + 0.5
        generator = torch.manual_seed(seed)

        inputs = {
            "prompt": [
                "A painting of a squirrel eating a burger",
                "A painting of a burger eating a squirrel",
            ],
            "image": dummy_image.cpu(),
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "generator": generator,
            "output_type": "numpy",
        }
        return inputs

    def get_dummy_inversion_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"):
        inputs = self.get_dummy_inversion_inputs(device, seed)

        if input_image_type == "pt":
            image = inputs["image"]
        elif input_image_type == "np":
            image = VaeImageProcessor.pt_to_numpy(inputs["image"])
        elif input_image_type == "pil":
            image = VaeImageProcessor.pt_to_numpy(inputs["image"])
            image = VaeImageProcessor.numpy_to_pil(image)
        else:
            raise ValueError(f"unsupported input_image_type {input_image_type}")

        inputs["image"] = image
        inputs["output_type"] = output_type

        return inputs

    def test_save_load_optional_components(self):
        if not hasattr(self.pipeline_class, "_optional_components"):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        # set all optional components to None and update pipeline config accordingly
        for optional_component in pipe._optional_components:
            setattr(pipe, optional_component, None)
        pipe.register_modules(**{optional_component: None for optional_component in pipe._optional_components})

        inputs = self.get_dummy_inputs(torch_device)
        output = pipe(**inputs)[0]

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
            pipe_loaded.to(torch_device)
            pipe_loaded.set_progress_bar_config(disable=None)

        for optional_component in pipe._optional_components:
            self.assertTrue(
                getattr(pipe_loaded, optional_component) is None,
                f"`{optional_component}` did not stay set to None after loading.",
            )

        inputs = self.get_dummy_inputs(torch_device)
        output_loaded = pipe_loaded(**inputs)[0]

        max_diff = np.abs(output - output_loaded).max()
        self.assertLess(max_diff, 1e-4)

    def test_stable_diffusion_pix2pix_zero_inversion(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inversion_inputs(device)
        inputs["image"] = inputs["image"][:1]
        inputs["prompt"] = inputs["prompt"][:1]
        image = sd_pipe.invert(**inputs).images
        image_slice = image[0, -3:, -3:, -1]
        assert image.shape == (1, 32, 32, 3)
        expected_slice = np.array([0.4732, 0.4630, 0.5722, 0.5103, 0.5140, 0.5622, 0.5104, 0.5390, 0.5020])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    def test_stable_diffusion_pix2pix_zero_inversion_batch(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inversion_inputs(device)
        image = sd_pipe.invert(**inputs).images
        image_slice = image[1, -3:, -3:, -1]
        assert image.shape == (2, 32, 32, 3)
        expected_slice = np.array([0.6046, 0.5400, 0.4902, 0.4448, 0.4694, 0.5498, 0.4857, 0.5073, 0.5089])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    def test_stable_diffusion_pix2pix_zero_default_case(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]
        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.4863, 0.5053, 0.5033, 0.4007, 0.3571, 0.4768, 0.5176, 0.5277, 0.4940])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    def test_stable_diffusion_pix2pix_zero_negative_prompt(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        negative_prompt = "french fries"
        output = sd_pipe(**inputs, negative_prompt=negative_prompt)
        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.5177, 0.5097, 0.5047, 0.4076, 0.3667, 0.4767, 0.5238, 0.5307, 0.4958])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    def test_stable_diffusion_pix2pix_zero_euler(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        components["scheduler"] = EulerAncestralDiscreteScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
        )
        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.5421, 0.5525, 0.6085, 0.5279, 0.4658, 0.5317, 0.4418, 0.4815, 0.5132])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    def test_stable_diffusion_pix2pix_zero_ddpm(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        components["scheduler"] = DDPMScheduler()
        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = sd_pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 64, 64, 3)
        expected_slice = np.array([0.4861, 0.5053, 0.5038, 0.3994, 0.3562, 0.4768, 0.5172, 0.5280, 0.4938])

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

    def test_stable_diffusion_pix2pix_zero_inversion_pt_np_pil_outputs_equivalent(self):
        device = torch_device
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        output_pt = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="pt")).images
        output_np = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="np")).images
        output_pil = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="pil")).images

        max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max()
        self.assertLess(max_diff, 1e-4, "`output_type=='pt'` generate different results from `output_type=='np'`")

        max_diff = np.abs(np.array(output_pil[0]) - (output_np[0] * 255).round()).max()
        self.assertLess(max_diff, 2.0, "`output_type=='pil'` generate different results from `output_type=='np'`")

    def test_stable_diffusion_pix2pix_zero_inversion_pt_np_pil_inputs_equivalent(self):
        device = torch_device
        components = self.get_dummy_components()
        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
        sd_pipe = sd_pipe.to(device)
        sd_pipe.set_progress_bar_config(disable=None)

        out_input_pt = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, input_image_type="pt")).images
        out_input_np = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, input_image_type="np")).images
        out_input_pil = sd_pipe.invert(
            **self.get_dummy_inversion_inputs_by_type(device, input_image_type="pil")
        ).images

        max_diff = np.abs(out_input_pt - out_input_np).max()
        self.assertLess(max_diff, 1e-4, "`input_type=='pt'` generate different result from `input_type=='np'`")

        assert_mean_pixel_difference(out_input_pil, out_input_np, expected_max_diff=1)

    # Non-determinism caused by the scheduler optimizing the latent inputs during inference
    @unittest.skip("non-deterministic pipeline")
    def test_inference_batch_single_identical(self):
        return super().test_inference_batch_single_identical()


@nightly
@require_torch_gpu
class StableDiffusionPix2PixZeroPipelineNightlyTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @classmethod
    def setUpClass(cls):
        cls.source_embeds = load_pt(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat.pt"
        )

        cls.target_embeds = load_pt(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog.pt"
        )

    def get_inputs(self, seed=0):
        generator = torch.manual_seed(seed)

        inputs = {
            "prompt": "turn him into a cyborg",
            "generator": generator,
            "num_inference_steps": 3,
            "guidance_scale": 7.5,
            "cross_attention_guidance_amount": 0.15,
            "source_embeds": self.source_embeds,
            "target_embeds": self.target_embeds,
            "output_type": "numpy",
        }
        return inputs

    def test_stable_diffusion_pix2pix_zero_default(self):
        pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
        )
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs()
        image = pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.5742, 0.5757, 0.5747, 0.5781, 0.5688, 0.5713, 0.5742, 0.5664, 0.5747])

        assert np.abs(expected_slice - image_slice).max() < 5e-2

    def test_stable_diffusion_pix2pix_zero_k_lms(self):
        pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
        )
        pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs()
        image = pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1].flatten()

        assert image.shape == (1, 512, 512, 3)
        expected_slice = np.array([0.6367, 0.5459, 0.5146, 0.5479, 0.4905, 0.4753, 0.4961, 0.4629, 0.4624])

        assert np.abs(expected_slice - image_slice).max() < 5e-2

    def test_stable_diffusion_pix2pix_zero_intermediate_state(self):
        number_of_steps = 0

        def callback_fn(step: int, timestep: int, latents: torch.FloatTensor) -> None:
            callback_fn.has_been_called = True
            nonlocal number_of_steps
            number_of_steps += 1
            if step == 1:
                latents = latents.detach().cpu().numpy()
                assert latents.shape == (1, 4, 64, 64)
                latents_slice = latents[0, -3:, -3:, -1]
                expected_slice = np.array([0.1345, 0.268, 0.1539, 0.0726, 0.0959, 0.2261, -0.2673, 0.0277, -0.2062])

                assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
            elif step == 2:
                latents = latents.detach().cpu().numpy()
                assert latents.shape == (1, 4, 64, 64)
                latents_slice = latents[0, -3:, -3:, -1]
                expected_slice = np.array([0.1393, 0.2637, 0.1617, 0.0724, 0.0987, 0.2271, -0.2666, 0.0299, -0.2104])

                assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2

        callback_fn.has_been_called = False

        pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
        )
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing()

        inputs = self.get_inputs()
        pipe(**inputs, callback=callback_fn, callback_steps=1)
        assert callback_fn.has_been_called
        assert number_of_steps == 3

    def test_stable_diffusion_pipeline_with_sequential_cpu_offloading(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
        )
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.enable_attention_slicing(1)
        pipe.enable_sequential_cpu_offload()

        inputs = self.get_inputs()
        _ = pipe(**inputs)

        mem_bytes = torch.cuda.max_memory_allocated()
        # make sure that less than 8.2 GB is allocated
        assert mem_bytes < 8.2 * 10**9


@nightly
@require_torch_gpu
class InversionPipelineNightlyTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @classmethod
    def setUpClass(cls):
        raw_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png"
        )

        raw_image = raw_image.convert("RGB").resize((512, 512))

        cls.raw_image = raw_image

    def test_stable_diffusion_pix2pix_inversion(self):
        pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
        )
        pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

        caption = "a photography of a cat with flowers"
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        output = pipe.invert(caption, image=self.raw_image, generator=generator, num_inference_steps=10)
        inv_latents = output[0]

        image_slice = inv_latents[0, -3:, -3:, -1].flatten()

        assert inv_latents.shape == (1, 4, 64, 64)
        expected_slice = np.array([0.8447, -0.0730, 0.7588, -1.2070, -0.4678, 0.1511, -0.8555, 1.1816, -0.7666])

        assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 5e-2

    def test_stable_diffusion_2_pix2pix_inversion(self):
        pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16
        )
        pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

        caption = "a photography of a cat with flowers"
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        output = pipe.invert(caption, image=self.raw_image, generator=generator, num_inference_steps=10)
        inv_latents = output[0]

        image_slice = inv_latents[0, -3:, -3:, -1].flatten()

        assert inv_latents.shape == (1, 4, 64, 64)
        expected_slice = np.array([0.8970, -0.1611, 0.4766, -1.1162, -0.5923, 0.1050, -0.9678, 1.0537, -0.6050])

        assert np.abs(expected_slice - image_slice.cpu().numpy()).max() < 5e-2

    def test_stable_diffusion_2_pix2pix_full(self):
        # numpy array of https://huggingface.co/datasets/hf-internal-testing/diffusers-images/blob/main/pix2pix/dog_2.png
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/dog_2.npy"
        )

        pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16
        )
        pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

        caption = "a photography of a cat with flowers"
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.set_progress_bar_config(disable=None)

        generator = torch.manual_seed(0)
        output = pipe.invert(caption, image=self.raw_image, generator=generator)
        inv_latents = output[0]

        source_prompts = 4 * ["a cat sitting on the street", "a cat playing in the field", "a face of a cat"]
        target_prompts = 4 * ["a dog sitting on the street", "a dog playing in the field", "a face of a dog"]

        source_embeds = pipe.get_embeds(source_prompts)
        target_embeds = pipe.get_embeds(target_prompts)

        image = pipe(
            caption,
            source_embeds=source_embeds,
            target_embeds=target_embeds,
            num_inference_steps=125,
            cross_attention_guidance_amount=0.015,
            generator=generator,
            latents=inv_latents,
            negative_prompt=caption,
            output_type="np",
        ).images

        mean_diff = np.abs(expected_image - image).mean()
        assert mean_diff < 0.25
@@ -13,8 +13,7 @@ from typing import Callable, Union
import numpy as np
import PIL.Image
import torch
from huggingface_hub import ModelCard, delete_repo
from huggingface_hub.utils import is_jinja_available
from huggingface_hub import delete_repo
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

import diffusers
@@ -1143,21 +1142,6 @@ class PipelinePushToHubTester(unittest.TestCase):
        # Reset repo
        delete_repo(self.org_repo_id, token=TOKEN)

    @unittest.skipIf(
        not is_jinja_available(),
        reason="Model card tests cannot be performed without Jinja installed.",
    )
    def test_push_to_hub_library_name(self):
        components = self.get_pipeline_components()
        pipeline = StableDiffusionPipeline(**components)
        pipeline.push_to_hub(self.repo_id, token=TOKEN)

        model_card = ModelCard.load(f"{USER}/{self.repo_id}", token=TOKEN).data
        assert model_card.library_name == "diffusers"

        # Reset repo
        delete_repo(self.repo_id, token=TOKEN)


# For SDXL and its derivative pipelines (such as ControlNet), we have the text encoders
# and the tokenizers as optional components. So, we need to override the `test_save_load_optional_components()`

@@ -15,9 +15,7 @@ ALWAYS_TEST_PIPELINE_MODULES = [
    "stable_diffusion",
    "stable_diffusion_2",
    "stable_diffusion_xl",
    "stable_diffusion_adapter",
    "deepfloyd_if",
    "ip_adapters",
    "kandinsky",
    "kandinsky2_2",
    "text_to_video_synthesis",