Compare commits

...

20 Commits

Author SHA1 Message Date
Suraj Patil
79ec3a8a39 Merge branch 'main' into v_prediction 2022-11-24 02:53:54 +01:00
Nathan Lambert
da5e677c18 remove Literal, add deprecates 2022-11-23 12:20:54 -08:00
Nathan Lambert
b70f6cd5e0 move expand_to_shape 2022-11-23 11:59:15 -08:00
Nathan Lambert
66951ec084 Update src/diffusers/schedulers/scheduling_ddpm.py
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
2022-11-23 07:42:21 -08:00
Nathan Lambert
172b242c2a fix loose comments 2022-11-17 14:58:52 -08:00
Nathan Lambert
e701a97838 change name from v to velocity 2022-11-17 14:56:19 -08:00
Nathan Lambert
c1a0584213 style 2022-11-17 14:51:59 -08:00
Nathan Lambert
3adf87b2d9 add ddim pred type test 2022-11-17 14:49:55 -08:00
Nathan Lambert
5a509dbedd Merge branch 'main' into v_prediction 2022-11-17 14:47:26 -08:00
Nathan Lambert
e39198306b fix tests 2022-11-17 14:43:14 -08:00
Ben Glickenhaus
11362ae5d2 V prediction ddim (#1313)
* v diffusion support for ddpm

* quality and style

* variable name consistency

* missing base case

* pass prediction type along in the pipeline

* put prediction type in scheduler config

* style

* try to train on ddim

* changes to ddim

* ddim v prediction works to train butterflies example

* fix bad merge, style and quality

* try to fix broken doc strings

* second pass

* one more

* white space

* Update src/diffusers/schedulers/scheduling_ddim.py

* remove extra lines

* Update src/diffusers/schedulers/scheduling_ddim.py

Co-authored-by: Ben Glickenhaus <ben@mail.cs.umass.edu>
Co-authored-by: Nathan Lambert <nathan@huggingface.co>
2022-11-17 10:26:19 -08:00
Nathan Lambert
56164f56fb quality 2022-11-09 11:53:25 -08:00
Nathan Lambert
8fe2ff4b16 Merge branch 'main' into v_prediction 2022-11-09 11:50:39 -08:00
Ben Glickenhaus
f00d896a1e DDPM changes to support v diffusion (#1121)
* v diffusion support for ddpm

* quality and style

* variable name consistency

* missing base case

* pass prediction type along in the pipeline

* put prediction type in scheduler config

* style
2022-11-09 11:33:15 -08:00
Nathan Lambert
ac6be90a71 style 2022-10-18 11:42:51 -07:00
Nathan Lambert
4c6850473d add ddim 2022-10-18 11:22:46 -07:00
Nathan Lambert
3eb2593d9a a few more additions 2022-10-12 20:10:03 -07:00
Nathan Lambert
7eb4bfae6c up 2022-10-12 17:39:48 -07:00
Nathan Lambert
b7d0c1e84a placeholder code 2022-10-12 17:32:52 -07:00
Nathan Lambert
798263f629 init v-pred pr 2022-10-12 17:24:36 -07:00
7 changed files with 443 additions and 61 deletions

View File

@@ -194,16 +194,28 @@ def parse_args():
)
parser.add_argument(
"--predict_epsilon",
action="store_true",
default=True,
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
"--prediction_type",
type=str,
default="epsilon",
help=(
"Whether the model should predict the 'epsilon'/noise error, directly the reconstructed image 'x0', or the"
" velocity of the ODE 'velocity'."
),
)
parser.add_argument("--ddpm_num_steps", type=int, default=1000)
parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
args = parser.parse_args()
message = (
"Please make sure to instantiate your training with `--prediction_type=epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_config(<model_id>, prediction_type=epsilon)`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=args)
if predict_epsilon:
args.prediction_type = "epsilon"
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
@@ -256,13 +268,13 @@ def main(args):
"UpBlock2D",
),
)
accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
if accepts_predict_epsilon:
if accepts_prediction_type:
noise_scheduler = DDPMScheduler(
num_train_timesteps=args.ddpm_num_steps,
beta_schedule=args.ddpm_beta_schedule,
predict_epsilon=args.predict_epsilon,
prediction_type=args.prediction_type,
)
else:
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
@@ -365,7 +377,7 @@ def main(args):
# Predict the noise residual
model_output = model(noisy_images, timesteps).sample
if args.predict_epsilon:
if args.prediction_type == "epsilon":
loss = F.mse_loss(model_output, noise) # this could have different weights!
else:
alpha_t = _extract_into_tensor(

View File

@@ -0,0 +1,227 @@
import glob
import os
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_cosine_schedule_with_warmup
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
@dataclass
class TrainingConfig:
image_size = 128 # the generated image resolution
train_batch_size = 16
eval_batch_size = 16 # how many images to sample during evaluation
num_epochs = 50
gradient_accumulation_steps = 1
learning_rate = 5e-5
lr_warmup_steps = 500
save_image_epochs = 10
save_model_epochs = 30
mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision
output_dir = "ddim-butterflies-128-v-diffusion" # the model namy locally and on the HF Hub
push_to_hub = False # whether to upload the saved model to the HF Hub
hub_private_repo = False
overwrite_output_dir = True # overwrite the old model when re-running the notebook
seed = 0
config = TrainingConfig()
config.dataset_name = "huggan/smithsonian_butterflies_subset"
dataset = load_dataset(config.dataset_name, split="train")
preprocess = transforms.Compose(
[
transforms.Resize((config.image_size, config.image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def transform(examples):
images = [preprocess(image.convert("RGB")) for image in examples["image"]]
return {"images": images}
dataset.set_transform(transform)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
model = UNet2DModel(
sample_size=config.image_size, # the target image resolution
in_channels=3, # the number of input channels, 3 for RGB images
out_channels=3, # the number of output channels
layers_per_block=2, # how many ResNet layers to use per UNet block
block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channes for each UNet block
down_block_types=(
"DownBlock2D", # a regular ResNet downsampling block
"DownBlock2D",
"DownBlock2D",
"DownBlock2D",
"AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
"DownBlock2D",
),
up_block_types=(
"UpBlock2D", # a regular ResNet upsampling block
"AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
"UpBlock2D",
),
)
if config.output_dir.startswith("ddpm"):
noise_scheduler = DDPMScheduler(
num_train_timesteps=1000,
beta_schedule="squaredcos_cap_v2",
variance_type="v_diffusion",
prediction_type="velocity",
)
else:
noise_scheduler = DDIMScheduler(
num_train_timesteps=1000,
beta_schedule="squaredcos_cap_v2",
variance_type="v_diffusion",
prediction_type="velocity",
)
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=config.lr_warmup_steps,
num_training_steps=(len(train_dataloader) * config.num_epochs),
)
def make_grid(images, rows, cols):
w, h = images[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, image in enumerate(images):
grid.paste(image, box=(i % cols * w, i // cols * h))
return grid
def evaluate(config, epoch, pipeline):
# Sample some images from random noise (this is the backward diffusion process).
# The default pipeline output type is `List[PIL.Image]`
images = pipeline(
batch_size=config.eval_batch_size,
generator=torch.manual_seed(config.seed),
).images
# Make a grid out of the images
image_grid = make_grid(images, rows=4, cols=4)
# Save the images
test_dir = os.path.join(config.output_dir, "samples")
os.makedirs(test_dir, exist_ok=True)
image_grid.save(f"{test_dir}/{epoch:04d}.png")
def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
# Initialize accelerator and tensorboard logging
accelerator = Accelerator(
mixed_precision=config.mixed_precision,
gradient_accumulation_steps=config.gradient_accumulation_steps,
log_with="tensorboard",
logging_dir=os.path.join(config.output_dir, "logs"),
)
if accelerator.is_main_process:
if config.push_to_hub:
repo = init_git_repo(config, at_init=True)
accelerator.init_trackers("train_example")
# Prepare everything
# There is no specific order to remember, you just need to unpack the
# objects in the same order you gave them to the prepare method.
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
global_step = 0
if config.output_dir.startswith("ddpm"):
pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
else:
pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
evaluate(config, 0, pipeline)
# Now you train the model
for epoch in range(config.num_epochs):
progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
progress_bar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader):
clean_images = batch["images"]
# Sample noise to add to the images
noise = torch.randn(clean_images.shape).to(clean_images.device)
bs = clean_images.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()
with accelerator.accumulate(model):
# Predict the noise residual
alpha_t, sigma_t = noise_scheduler.get_alpha_sigma(clean_images, timesteps, accelerator.device)
z_t = alpha_t * clean_images + sigma_t * noise
noise_pred = model(z_t, timesteps).sample
v = alpha_t * noise - sigma_t * clean_images
loss = F.mse_loss(noise_pred, v)
accelerator.backward(loss)
accelerator.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
global_step += 1
# After each epoch you optionally sample some demo images with evaluate() and save the model
if accelerator.is_main_process:
if config.output_dir.startswith("ddpm"):
pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
else:
pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
evaluate(config, epoch, pipeline)
if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
if config.push_to_hub:
push_to_hub(config, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=True)
else:
pipeline.save_pretrained(config.output_dir)
args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
train_loop(*args)
sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
Image.open(sample_images[-1])

View File

@@ -24,7 +24,7 @@ import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
from .scheduling_utils import SchedulerMixin
from .scheduling_utils import SchedulerMixin, expand_to_shape
@dataclass
@@ -75,6 +75,18 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
return torch.tensor(betas)
def t_to_alpha_sigma(num_diffusion_timesteps):
"""Returns the scaling factors for the clean image and for the noise, given
a timestep."""
alphas = torch.cos(
torch.tensor([(t / num_diffusion_timesteps) * math.pi / 2 for t in range(num_diffusion_timesteps)])
)
sigmas = torch.sin(
torch.tensor([(t / num_diffusion_timesteps) * math.pi / 2 for t in range(num_diffusion_timesteps)])
)
return alphas, sigmas
class DDIMScheduler(SchedulerMixin, ConfigMixin):
"""
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
@@ -106,6 +118,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
an offset added to the inference steps. You can use a combination of `offset=1` and
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
stable diffusion.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `velocity` (see section 2.4
https://imagen.research.google/video/paper.pdf)
"""
@@ -121,7 +137,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[np.ndarray] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
variance_type: str = "fixed",
steps_offset: int = 0,
prediction_type: str = "epsilon",
**kwargs,
):
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
@@ -138,14 +157,22 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.variance_type = variance_type
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
if prediction_type == "velocity":
self.alphas, self.sigmas = t_to_alpha_sigma(num_train_timesteps)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
if set_alpha_to_one:
self.final_alpha_cumprod = torch.tensor(1.0)
self.final_sigma = torch.tensor(0.0) # TODO rename set_alpha_to_one for something general with sigma=0
else:
self.final_alpha_cumprod = self.alphas_cumprod[0]
self.final_sigma = self.sigmas[0] if prediction_type == "velocity" else None
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
@@ -153,6 +180,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
self.variance_type = variance_type
self.prediction_type = prediction_type
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
@@ -162,20 +191,31 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
Args:
sample (`torch.FloatTensor`): input sample
timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
"""
return sample
def _get_variance(self, timestep, prev_timestep):
def _get_variance(self, timestep, prev_timestep, eta=0):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
if self.variance_type == "fixed":
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
elif self.variance_type == "v_diffusion":
# If eta > 0, adjust the scaling factor for the predicted noise
# downward according to the amount of additional noise to add
alpha_prev = self.alphas[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
sigma_prev = self.sigmas[prev_timestep] if prev_timestep >= 0 else self.final_sigma
if eta:
numerator = eta * (sigma_prev**2 / self.sigmas[timestep] ** 2).clamp(min=1.0e-7).sqrt()
else:
numerator = 0
denominator = (1 - self.alphas[timestep] ** 2 / alpha_prev**2).clamp(min=1.0e-7).sqrt()
ddim_sigma = (numerator * denominator).clamp(min=1.0e-7)
variance = (sigma_prev**2 - ddim_sigma**2).clamp(min=1.0e-7).sqrt()
return variance
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
@@ -240,14 +280,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_sample -> f_theta(x_t, t) or x_0
# - pred_noise_t -> e_theta(x_t, timestep)
# - pred_original_sample -> f_theta(x_t, timestep) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
# 1. get previous step value (=t-1)
# 1. get previous step value (=timestep-1)
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
# 2. compute alphas, betas
@@ -258,7 +298,21 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
if self.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
eps = torch.tensor(1)
elif self.prediction_type == "sample":
pred_original_sample = model_output
eps = torch.tensor(1)
elif self.prediction_type == "velocity":
# v_t = alpha_t * epsilon - sigma_t * x
# need to merge the PRs for sigma to be available in DDPM
pred_original_sample = sample * self.alphas[timestep] - model_output * self.sigmas[timestep]
eps = model_output * self.alphas[timestep] + sample * self.sigmas[timestep]
else:
raise ValueError(
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `velocity`"
)
# 4. Clip "predicted x_0"
if self.config.clip_sample:
@@ -266,7 +320,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 α_t1)/(1 α_t)) * sqrt(1 α_t/α_t1)
variance = self._get_variance(timestep, prev_timestep)
variance = self._get_variance(timestep, prev_timestep, eta)
std_dev_t = eta * variance ** (0.5)
if use_clipped_model_output:
@@ -274,10 +328,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
if self.prediction_type == "epsilon":
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + eps * pred_sample_direction
else:
alpha_prev = self.alphas[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
prev_sample = pred_original_sample * alpha_prev + eps * variance
if eta > 0:
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
@@ -300,7 +358,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
prev_sample = prev_sample + variance
if not return_dict:
return (prev_sample,)
@@ -312,6 +369,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
if self.variance_type == "v_diffusion":
alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device)
z_t = alpha * original_samples + sigma * noise
return z_t
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
@@ -331,3 +392,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
def __len__(self):
return self.config.num_train_timesteps
def get_alpha_sigma(self, sample, timesteps, device):
alpha = expand_to_shape(self.alphas, timesteps, sample.shape, device)
sigma = expand_to_shape(self.sigmas, timesteps, sample.shape, device)
return alpha, sigma

View File

@@ -23,7 +23,7 @@ import torch
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
from .scheduling_utils import SchedulerMixin
from .scheduling_utils import SchedulerMixin, expand_to_shape
@dataclass
@@ -99,9 +99,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
predict_epsilon (`bool`):
optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `velocity` (see section 2.4
https://imagen.research.google/video/paper.pdf)
predict_epsilon (`bool`, default `True`):
deprecated flag (removing v0.10.0) for epsilon vs. direct sample prediction.
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@@ -116,6 +119,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
trained_betas: Optional[np.ndarray] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
prediction_type: str = "epsilon",
predict_epsilon: bool = True,
):
if trained_betas is not None:
@@ -139,7 +143,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.one = torch.tensor(1.0)
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
@@ -149,6 +154,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.variance_type = variance_type
self.prediction_type = prediction_type
message = (
"Please make sure to instantiate your scheduler with `prediction_type=epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_config(<model_id>, prediction_type='epsilon')`."
)
deprecate("predict_epsilon", "0.10.0", message)
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
@@ -179,14 +191,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps).to(device)
def _get_variance(self, t, predicted_variance=None, variance_type=None):
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
def _get_variance(self, timestep, predicted_variance=None, variance_type=None):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0)
# For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
# For timestep > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
# and sample from it to get previous sample
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
# x_{timestep-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep]
if variance_type is None:
variance_type = self.config.variance_type
@@ -199,17 +211,19 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance = torch.log(torch.clamp(variance, min=1e-20))
variance = torch.exp(0.5 * variance)
elif variance_type == "fixed_large":
variance = self.betas[t]
variance = self.betas[timestep]
elif variance_type == "fixed_large_log":
# Glide max_log
variance = torch.log(self.betas[t])
variance = torch.log(self.betas[timestep])
elif variance_type == "learned":
return predicted_variance
elif variance_type == "learned_range":
min_log = variance
max_log = self.betas[t]
max_log = self.betas[timestep]
frac = (predicted_variance + 1) / 2
variance = frac * max_log + (1 - frac) * min_log
elif variance_type == "v_diffusion":
variance = torch.log(self.betas[timestep] * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t))
return variance
@@ -240,9 +254,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
returning a tuple, the first element is the sample tensor.
"""
if self.variance_type == "v_diffusion":
assert self.prediction_type == "velocity", "Need to use v prediction with v_diffusion"
message = (
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
"Please make sure to instantiate your scheduler with `prediction_type=epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_config(<model_id>, prediction_type=epsilon)`."
)
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
@@ -250,34 +266,46 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
new_config["predict_epsilon"] = predict_epsilon
self._internal_dict = FrozenDict(new_config)
t = timestep
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
predicted_variance = None
# 1. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if self.config.predict_epsilon:
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
else:
if self.prediction_type == "velocity":
# x_recon in p_mean_variance
pred_original_sample = (
sample * self.sqrt_alphas_cumprod[timestep]
- model_output * self.sqrt_one_minus_alphas_cumprod[timestep]
)
# not check on predict_epsilon for depreciation flag above
elif self.prediction_type == "sample" or not self.config.predict_epsilon:
pred_original_sample = model_output
elif self.prediction_type == "epsilon" or self.config.predict_epsilon:
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
else:
raise ValueError(
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `velocity`"
)
# 3. Clip "predicted x_0"
if self.config.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t
current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t
# 5. Compute predicted previous sample µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
@@ -285,7 +313,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
# 6. Add noise
variance = 0
if t > 0:
if timestep > 0:
device = model_output.device
if device.type == "mps":
# randn does not work reproducibly on mps
@@ -296,9 +324,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
)
if self.variance_type == "fixed_small_log":
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
variance = self._get_variance(timestep, predicted_variance=predicted_variance) * variance_noise
elif self.variance_type == "v_diffusion":
variance = torch.exp(0.5 * self._get_variance(timestep, predicted_variance)) * variance_noise
else:
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
variance = (
self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5
) * variance_noise
pred_prev_sample = pred_prev_sample + variance
@@ -313,6 +345,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
if self.variance_type == "v_diffusion":
alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device)
z_t = alpha * original_samples + sigma * noise
return z_t
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
timesteps = timesteps.to(original_samples.device)
@@ -332,3 +369,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
def __len__(self):
return self.config.num_train_timesteps
def get_alpha_sigma(self, sample, timesteps, device):
alpha = expand_to_shape(self.sqrt_alphas_cumprod, timesteps, sample.shape, device)
sigma = expand_to_shape(self.sqrt_one_minus_alphas_cumprod, timesteps, sample.shape, device)
return alpha, sigma

View File

@@ -21,7 +21,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@@ -88,9 +88,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
sampling, and `solver_order=3` for unconditional sampling.
predict_epsilon (`bool`, default `True`):
we currently support both the noise prediction model and the data prediction model. If the model predicts
the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set
`predict_epsilon` to `False`.
deprecated flag (removing v0.10.0); we currently support both the noise prediction model and the data
prediction model. If the model predicts the noise / epsilon, set `predict_epsilon` to `True`. If the model
predicts the data / x0 directly, set `predict_epsilon` to `False`.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `velocity` (see section 2.4
https://imagen.research.google/video/paper.pdf)
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
@@ -128,6 +132,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
solver_order: int = 2,
prediction_type: str = "epsilon",
predict_epsilon: bool = True,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
@@ -174,6 +179,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
if prediction_type not in ["epsilon", "sample"]:
raise ValueError(
f"Prediction type {self.config.prediction_type} not supported by DPMSolverMultistepScheduler"
)
message = (
"Please make sure to instantiate your scheduler with `prediction_type=epsilon` instead. E.g. `scheduler ="
" DDPMScheduler.from_config(<model_id>, prediction_type='epsilon')`."
)
deprecate("predict_epsilon", "0.10.0", message)
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -221,11 +237,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
# DPM-Solver++ needs to solve an integral of the data prediction model.
if self.config.algorithm_type == "dpmsolver++":
if self.config.predict_epsilon:
if self.config.prediction_type == "epsilon":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
x0_pred = (sample - sigma_t * model_output) / alpha_t
else:
elif self.config.prediction_type == "sample":
x0_pred = model_output
else:
raise ValueError(
f"Prediction type {self.config.prediction_type} not supported by DPMSolverMultistepScheduler"
)
if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = torch.quantile(

View File

@@ -152,3 +152,14 @@ class SchedulerMixin:
getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
]
return compatible_classes
def expand_to_shape(input, timesteps, shape, device):
"""
Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast
nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once.
"""
out = torch.gather(input.to(device), 0, timesteps.to(device))
reshape = [shape[0]] + [1] * (len(shape) - 1)
out = out.reshape(*reshape)
return out

View File

@@ -599,9 +599,9 @@ class DDPMSchedulerTest(SchedulerCommonTest):
for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample)
def test_predict_epsilon(self):
for predict_epsilon in [True, False]:
self.check_over_configs(predict_epsilon=predict_epsilon)
def test_prediction_type(self):
for prediction_type in ["epsilon", "sample", "velocity"]:
self.check_over_configs(prediction_type=prediction_type)
def test_deprecated_epsilon(self):
deprecate("remove this test", "0.10.0", "remove")
@@ -613,7 +613,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
time_step = 4
scheduler = scheduler_class(**scheduler_config)
scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config)
scheduler_eps = scheduler_class(prediction_type="sample", **scheduler_config)
kwargs = {}
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
@@ -728,6 +728,10 @@ class DDIMSchedulerTest(SchedulerCommonTest):
for schedule in ["linear", "squaredcos_cap_v2"]:
self.check_over_configs(beta_schedule=schedule)
def test_prediction_type(self):
for prediction_type in ["epsilon", "sample", "velocity"]:
self.check_over_configs(prediction_type=prediction_type)
def test_clip_sample(self):
for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample)