|
|
|
|
@@ -0,0 +1,218 @@
|
|
|
|
|
# Copyright 2022 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved.
|
|
|
|
|
#
|
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
|
#
|
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
#
|
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
from typing import Optional, Tuple, Union
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from ..configuration_utils import ConfigMixin, register_to_config
|
|
|
|
|
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|
|
|
|
"""
|
|
|
|
|
Args:
|
|
|
|
|
Implements Algorithm 2 (Heun steps) from Karras et al. (2022). for discrete beta schedules. Based on the original
|
|
|
|
|
k-diffusion implementation by Katherine Crowson:
|
|
|
|
|
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L90
|
|
|
|
|
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
|
|
|
|
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
|
|
|
|
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
|
|
|
|
[`~ConfigMixin.from_config`] functions.
|
|
|
|
|
num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the
|
|
|
|
|
starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`):
|
|
|
|
|
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
|
|
|
|
|
`linear` or `scaled_linear`.
|
|
|
|
|
trained_betas (`np.ndarray`, optional):
|
|
|
|
|
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
|
|
|
|
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
|
|
|
|
|
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
|
|
|
|
tensor_format (`str`): whether the scheduler expects pytorch or numpy arrays.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
@register_to_config
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
num_train_timesteps: int = 1000,
|
|
|
|
|
beta_start: float = 0.00085, # sensible defaults
|
|
|
|
|
beta_end: float = 0.012,
|
|
|
|
|
beta_schedule: str = "linear",
|
|
|
|
|
trained_betas: Optional[np.ndarray] = None,
|
|
|
|
|
):
|
|
|
|
|
if trained_betas is not None:
|
|
|
|
|
self.betas = torch.from_numpy(trained_betas)
|
|
|
|
|
elif beta_schedule == "linear":
|
|
|
|
|
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
|
|
|
|
elif beta_schedule == "scaled_linear":
|
|
|
|
|
# this schedule is very specific to the latent diffusion model.
|
|
|
|
|
self.betas = (
|
|
|
|
|
torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
|
|
|
|
|
|
|
|
|
self.alphas = 1.0 - self.betas
|
|
|
|
|
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
|
|
|
|
|
|
|
|
|
# set all values
|
|
|
|
|
self.set_timesteps(num_train_timesteps, None, num_train_timesteps)
|
|
|
|
|
|
|
|
|
|
def scale_model_input(
|
|
|
|
|
self,
|
|
|
|
|
sample: torch.FloatTensor,
|
|
|
|
|
timestep: Union[float, torch.FloatTensor],
|
|
|
|
|
) -> torch.FloatTensor:
|
|
|
|
|
"""
|
|
|
|
|
Args:
|
|
|
|
|
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
|
|
|
|
current timestep.
|
|
|
|
|
sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep
|
|
|
|
|
Returns:
|
|
|
|
|
`torch.FloatTensor`: scaled input sample
|
|
|
|
|
"""
|
|
|
|
|
step_index = (self.timesteps == timestep).nonzero().item()
|
|
|
|
|
sigma = self.sigmas[step_index]
|
|
|
|
|
sample = sample / ((sigma**2 + 1) ** 0.5)
|
|
|
|
|
return sample
|
|
|
|
|
|
|
|
|
|
def set_timesteps(
|
|
|
|
|
self,
|
|
|
|
|
num_inference_steps: int,
|
|
|
|
|
device: Union[str, torch.device] = None,
|
|
|
|
|
num_train_timesteps: Optional[int] = None,
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
num_inference_steps (`int`):
|
|
|
|
|
the number of diffusion steps used when generating samples with a pre-trained model.
|
|
|
|
|
device (`str` or `torch.device`, optional):
|
|
|
|
|
the device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
|
|
|
|
"""
|
|
|
|
|
self.num_inference_steps = num_inference_steps
|
|
|
|
|
|
|
|
|
|
num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps
|
|
|
|
|
|
|
|
|
|
timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
|
|
|
|
|
|
|
|
|
|
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
|
|
|
|
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
|
|
|
|
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
|
|
|
|
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
|
|
|
|
|
|
|
|
|
timesteps = torch.from_numpy(timesteps)
|
|
|
|
|
|
|
|
|
|
# standard deviation of the initial noise distribution
|
|
|
|
|
self.init_noise_sigma = sigmas[0]
|
|
|
|
|
|
|
|
|
|
if str(device).startswith("mps"):
|
|
|
|
|
# mps does not support float64
|
|
|
|
|
self.timesteps = timesteps.to(device, dtype=torch.float32)
|
|
|
|
|
else:
|
|
|
|
|
self.timesteps = timesteps.to(device=device)
|
|
|
|
|
|
|
|
|
|
# empty dt and derivative
|
|
|
|
|
self.prev_derivative = None
|
|
|
|
|
self.dt = None
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def state_in_first_order(self):
|
|
|
|
|
return self.dt is None
|
|
|
|
|
|
|
|
|
|
def step(
|
|
|
|
|
self,
|
|
|
|
|
model_output: Union[torch.FloatTensor, np.ndarray],
|
|
|
|
|
timestep: Union[float, torch.FloatTensor],
|
|
|
|
|
sample: Union[torch.FloatTensor, np.ndarray],
|
|
|
|
|
return_dict: bool = True,
|
|
|
|
|
) -> Union[SchedulerOutput, Tuple]:
|
|
|
|
|
"""
|
|
|
|
|
Args:
|
|
|
|
|
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
|
|
|
|
process from the learned model outputs (most often the predicted noise).
|
|
|
|
|
model_output (`torch.FloatTensor` or `np.ndarray`): direct output from learned diffusion model. timestep
|
|
|
|
|
(`int`): current discrete timestep in the diffusion chain. sample (`torch.FloatTensor` or `np.ndarray`):
|
|
|
|
|
current instance of sample being created by diffusion process.
|
|
|
|
|
return_dict (`bool`): option for returning tuple rather than SchedulerOutput class
|
|
|
|
|
Returns:
|
|
|
|
|
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
|
|
|
|
|
[`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
|
|
|
|
returning a tuple, the first element is the sample tensor.
|
|
|
|
|
"""
|
|
|
|
|
step_index = (self.timesteps == timestep).nonzero().item()
|
|
|
|
|
|
|
|
|
|
if self.state_in_first_order:
|
|
|
|
|
sigma = self.sigmas[step_index]
|
|
|
|
|
step_index += 1
|
|
|
|
|
sigma_next = self.sigmas[step_index]
|
|
|
|
|
sigma_hat = sigma
|
|
|
|
|
else:
|
|
|
|
|
# 2nd order / Heun's method
|
|
|
|
|
sigma = self.sigmas[step_index - 1]
|
|
|
|
|
sigma_next = self.sigmas[step_index]
|
|
|
|
|
sigma_hat = sigma_next
|
|
|
|
|
|
|
|
|
|
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
|
|
|
|
pred_original_sample = sample - sigma_hat * model_output
|
|
|
|
|
|
|
|
|
|
# 2. Convert to an ODE derivative
|
|
|
|
|
derivative = (sample - pred_original_sample) / sigma_hat
|
|
|
|
|
if self.state_in_first_order:
|
|
|
|
|
# 3. 1st order derivative
|
|
|
|
|
dt = sigma_next - sigma_hat
|
|
|
|
|
|
|
|
|
|
# store for 2nd order step
|
|
|
|
|
self.sample = sample
|
|
|
|
|
self.prev_derivative = derivative
|
|
|
|
|
self.dt = dt
|
|
|
|
|
else:
|
|
|
|
|
# 2. 2nd order / Heun's method
|
|
|
|
|
derivative = (self.prev_derivative + derivative) / 2
|
|
|
|
|
|
|
|
|
|
# 3. Retrieve 1st order derivative
|
|
|
|
|
dt = self.dt
|
|
|
|
|
|
|
|
|
|
# free dt and derivative
|
|
|
|
|
# Note, this puts the scheduler in "first order mode"
|
|
|
|
|
self.prev_derivative = None
|
|
|
|
|
self.dt = None
|
|
|
|
|
|
|
|
|
|
prev_sample = self.sample + derivative * dt
|
|
|
|
|
print(f"step_index: {step_index}, state_in_first_order: {self.state_in_first_order}, sigma: {sigma}, sigma_next: {sigma_next}, sigma_hat: {sigma_hat}, dt: {dt}")
|
|
|
|
|
|
|
|
|
|
if not return_dict:
|
|
|
|
|
return (prev_sample, self.timesteps[step_index])
|
|
|
|
|
|
|
|
|
|
return SchedulerOutput(prev_sample=prev_sample, timestep=self.timesteps[step_index])
|
|
|
|
|
|
|
|
|
|
def add_noise(
|
|
|
|
|
self,
|
|
|
|
|
original_samples: Union[torch.FloatTensor, np.ndarray],
|
|
|
|
|
noise: Union[torch.FloatTensor, np.ndarray],
|
|
|
|
|
timesteps: Union[torch.IntTensor, np.ndarray],
|
|
|
|
|
) -> Union[torch.FloatTensor, np.ndarray]:
|
|
|
|
|
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
|
|
|
|
self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
|
|
|
|
self.timesteps = self.timesteps.to(original_samples.device)
|
|
|
|
|
sigma = self.sigmas[timesteps].flatten()
|
|
|
|
|
while len(sigma.shape) < len(original_samples.shape):
|
|
|
|
|
sigma = sigma.unsqueeze(-1)
|
|
|
|
|
|
|
|
|
|
noisy_samples = original_samples + noise * sigma
|
|
|
|
|
return noisy_samples
|
|
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
|
return self.config.num_train_timesteps
|