Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-07 04:54:47 +08:00)

Compare commits: sf-comfy-l...v_predicti (20 commits)
| SHA1 |
|---|
| 79ec3a8a39 |
| da5e677c18 |
| b70f6cd5e0 |
| 66951ec084 |
| 172b242c2a |
| e701a97838 |
| c1a0584213 |
| 3adf87b2d9 |
| 5a509dbedd |
| e39198306b |
| 11362ae5d2 |
| 56164f56fb |
| 8fe2ff4b16 |
| f00d896a1e |
| ac6be90a71 |
| 4c6850473d |
| 3eb2593d9a |
| 7eb4bfae6c |
| b7d0c1e84a |
| 798263f629 |
examples/unconditional_image_generation/train_unconditional.py:

```diff
@@ -194,16 +194,28 @@ def parse_args():
     )
     parser.add_argument(
-        "--predict_epsilon",
-        action="store_true",
-        default=True,
-        help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
+        "--prediction_type",
+        type=str,
+        default="epsilon",
+        help=(
+            "Whether the model should predict the 'epsilon'/noise error, directly the reconstructed image 'x0', or the"
+            " velocity of the ODE 'velocity'."
+        ),
     )
     parser.add_argument("--ddpm_num_steps", type=int, default=1000)
     parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
 
     args = parser.parse_args()
 
+    message = (
+        "Please make sure to instantiate your training with `--prediction_type=epsilon` instead. E.g. `scheduler ="
+        " DDPMScheduler.from_config(<model_id>, prediction_type=epsilon)`."
+    )
+    predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=args)
+    if predict_epsilon:
+        args.prediction_type = "epsilon"
+
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
@@ -256,13 +268,13 @@ def main(args):
             "UpBlock2D",
         ),
     )
-    accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
+    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
 
-    if accepts_predict_epsilon:
+    if accepts_prediction_type:
         noise_scheduler = DDPMScheduler(
             num_train_timesteps=args.ddpm_num_steps,
             beta_schedule=args.ddpm_beta_schedule,
-            predict_epsilon=args.predict_epsilon,
+            prediction_type=args.prediction_type,
         )
     else:
         noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
@@ -365,7 +377,7 @@ def main(args):
             # Predict the noise residual
             model_output = model(noisy_images, timesteps).sample
 
-            if args.predict_epsilon:
+            if args.prediction_type == "epsilon":
                 loss = F.mse_loss(model_output, noise)  # this could have different weights!
             else:
                 alpha_t = _extract_into_tensor(
```
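For orientation, the three `--prediction_type` values correspond to three regression targets. A minimal sketch of the mapping (the `training_loss` helper and its argument names are illustrative, not part of this diff; the velocity target matches the one computed in `train_butterflies.py` below):

```python
import torch.nn.functional as F


def training_loss(model_output, clean_images, noise, alpha_t, sigma_t, prediction_type="epsilon"):
    # Pick the regression target that matches the scheduler's parameterization.
    if prediction_type == "epsilon":
        target = noise  # model predicts the added noise
    elif prediction_type == "sample":
        target = clean_images  # model predicts x0 directly
    elif prediction_type == "velocity":
        # v_t = alpha_t * eps - sigma_t * x0 (section 2.4 of https://imagen.research.google/video/paper.pdf)
        target = alpha_t * noise - sigma_t * clean_images
    else:
        raise ValueError(f"unknown prediction_type: {prediction_type}")
    return F.mse_loss(model_output, target)
```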
examples/v_prediction/train_butterflies.py (new file, 227 lines):
```python
import glob
import os
from dataclasses import dataclass

import torch
import torch.nn.functional as F

from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_cosine_schedule_with_warmup
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm


@dataclass
class TrainingConfig:
    image_size = 128  # the generated image resolution
    train_batch_size = 16
    eval_batch_size = 16  # how many images to sample during evaluation
    num_epochs = 50
    gradient_accumulation_steps = 1
    learning_rate = 5e-5
    lr_warmup_steps = 500
    save_image_epochs = 10
    save_model_epochs = 30
    mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
    output_dir = "ddim-butterflies-128-v-diffusion"  # the model name locally and on the HF Hub

    push_to_hub = False  # whether to upload the saved model to the HF Hub
    hub_private_repo = False
    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
    seed = 0


config = TrainingConfig()


config.dataset_name = "huggan/smithsonian_butterflies_subset"
dataset = load_dataset(config.dataset_name, split="train")


preprocess = transforms.Compose(
    [
        transforms.Resize((config.image_size, config.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)


def transform(examples):
    images = [preprocess(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}


dataset.set_transform(transform)


train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)


model = UNet2DModel(
    sample_size=config.image_size,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)


if config.output_dir.startswith("ddpm"):
    noise_scheduler = DDPMScheduler(
        num_train_timesteps=1000,
        beta_schedule="squaredcos_cap_v2",
        variance_type="v_diffusion",
        prediction_type="velocity",
    )
else:
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule="squaredcos_cap_v2",
        variance_type="v_diffusion",
        prediction_type="velocity",
    )


optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)


lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * config.num_epochs),
)


def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


def evaluate(config, epoch, pipeline):
    # Sample some images from random noise (this is the backward diffusion process).
    # The default pipeline output type is `List[PIL.Image]`
    images = pipeline(
        batch_size=config.eval_batch_size,
        generator=torch.manual_seed(config.seed),
    ).images

    # Make a grid out of the images
    image_grid = make_grid(images, rows=4, cols=4)

    # Save the images
    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.png")


def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        logging_dir=os.path.join(config.output_dir, "logs"),
    )
    if accelerator.is_main_process:
        if config.push_to_hub:
            repo = init_git_repo(config, at_init=True)
        accelerator.init_trackers("train_example")

    # Prepare everything
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0

    if config.output_dir.startswith("ddpm"):
        pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
    else:
        pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

    evaluate(config, 0, pipeline)

    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            clean_images = batch["images"]
            # Sample noise to add to the images
            noise = torch.randn(clean_images.shape).to(clean_images.device)
            bs = clean_images.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()

            with accelerator.accumulate(model):
                # Form the noisy sample and predict the velocity target
                alpha_t, sigma_t = noise_scheduler.get_alpha_sigma(clean_images, timesteps, accelerator.device)
                z_t = alpha_t * clean_images + sigma_t * noise
                noise_pred = model(z_t, timesteps).sample
                v = alpha_t * noise - sigma_t * clean_images
                loss = F.mse_loss(noise_pred, v)
                accelerator.backward(loss)

                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # After each epoch you optionally sample some demo images with evaluate() and save the model
        if accelerator.is_main_process:
            if config.output_dir.startswith("ddpm"):
                pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
            else:
                pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                evaluate(config, epoch, pipeline)

            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                if config.push_to_hub:
                    push_to_hub(config, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=True)
                else:
                    pipeline.save_pretrained(config.output_dir)


args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)

train_loop(*args)

sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
Image.open(sample_images[-1])
```
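The training loop above relies on the v-parameterization identities: with alpha_t = cos(t/T * π/2) and sigma_t = sin(t/T * π/2), alpha_t² + sigma_t² = 1, so x₀ and ε can be recovered from z_t and v. A small self-contained check of that algebra (illustrative, not part of the diff; these are exactly the reconstructions the DDIM step below performs):

```python
import torch

# With alpha**2 + sigma**2 == 1:
#   z = alpha * x + sigma * eps
#   v = alpha * eps - sigma * x
# invert to
#   x   = alpha * z - sigma * v
#   eps = sigma * z + alpha * v
t = torch.rand(())
alpha, sigma = torch.cos(t), torch.sin(t)  # satisfies alpha**2 + sigma**2 == 1
x, eps = torch.randn(4), torch.randn(4)
z = alpha * x + sigma * eps
v = alpha * eps - sigma * x
assert torch.allclose(alpha * z - sigma * v, x, atol=1e-6)
assert torch.allclose(sigma * z + alpha * v, eps, atol=1e-6)
```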
src/diffusers/schedulers/scheduling_ddim.py:

```diff
@@ -24,7 +24,7 @@ import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import SchedulerMixin, expand_to_shape
 
 
 @dataclass
@@ -75,6 +75,18 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
     return torch.tensor(betas)
 
 
+def t_to_alpha_sigma(num_diffusion_timesteps):
+    """Returns the scaling factors for the clean image and for the noise, given
+    a timestep."""
+    alphas = torch.cos(
+        torch.tensor([(t / num_diffusion_timesteps) * math.pi / 2 for t in range(num_diffusion_timesteps)])
+    )
+    sigmas = torch.sin(
+        torch.tensor([(t / num_diffusion_timesteps) * math.pi / 2 for t in range(num_diffusion_timesteps)])
+    )
+    return alphas, sigmas
+
+
 class DDIMScheduler(SchedulerMixin, ConfigMixin):
     """
     Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
@@ -106,6 +118,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         an offset added to the inference steps. You can use a combination of `offset=1` and
         `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
         stable diffusion.
+        prediction_type (`str`, default `epsilon`, optional):
+            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+            process), `sample` (directly predicting the noisy sample`) or `velocity` (see section 2.4
+            https://imagen.research.google/video/paper.pdf)
 
     """
 
@@ -121,7 +137,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         trained_betas: Optional[np.ndarray] = None,
         clip_sample: bool = True,
         set_alpha_to_one: bool = True,
+        variance_type: str = "fixed",
         steps_offset: int = 0,
+        prediction_type: str = "epsilon",
+        **kwargs,
     ):
         if trained_betas is not None:
             self.betas = torch.from_numpy(trained_betas)
@@ -138,14 +157,22 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
+        self.variance_type = variance_type
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+        if prediction_type == "velocity":
+            self.alphas, self.sigmas = t_to_alpha_sigma(num_train_timesteps)
 
         # At every step in ddim, we are looking into the previous alphas_cumprod
         # For the final step, there is no previous alphas_cumprod because we are already at 0
         # `set_alpha_to_one` decides whether we set this parameter simply to one or
         # whether we use the final alpha of the "non-previous" one.
-        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+        if set_alpha_to_one:
+            self.final_alpha_cumprod = torch.tensor(1.0)
+            self.final_sigma = torch.tensor(0.0)  # TODO rename set_alpha_to_one for something general with sigma=0
+        else:
+            self.final_alpha_cumprod = self.alphas_cumprod[0]
+            self.final_sigma = self.sigmas[0] if prediction_type == "velocity" else None
 
         # standard deviation of the initial noise distribution
         self.init_noise_sigma = 1.0
@@ -153,6 +180,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # setable values
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+        self.variance_type = variance_type
+        self.prediction_type = prediction_type
 
     def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
         """
@@ -162,20 +191,31 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         Args:
             sample (`torch.FloatTensor`): input sample
             timestep (`int`, optional): current timestep
 
         Returns:
             `torch.FloatTensor`: scaled input sample
         """
         return sample
 
-    def _get_variance(self, timestep, prev_timestep):
+    def _get_variance(self, timestep, prev_timestep, eta=0):
         alpha_prod_t = self.alphas_cumprod[timestep]
         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
 
-        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
-
+        if self.variance_type == "fixed":
+            variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+        elif self.variance_type == "v_diffusion":
+            # If eta > 0, adjust the scaling factor for the predicted noise
+            # downward according to the amount of additional noise to add
+            alpha_prev = self.alphas[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+            sigma_prev = self.sigmas[prev_timestep] if prev_timestep >= 0 else self.final_sigma
+            if eta:
+                numerator = eta * (sigma_prev**2 / self.sigmas[timestep] ** 2).clamp(min=1.0e-7).sqrt()
+            else:
+                numerator = 0
+            denominator = (1 - self.alphas[timestep] ** 2 / alpha_prev**2).clamp(min=1.0e-7).sqrt()
+            ddim_sigma = (numerator * denominator).clamp(min=1.0e-7)
+            variance = (sigma_prev**2 - ddim_sigma**2).clamp(min=1.0e-7).sqrt()
         return variance
 
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
@@ -240,14 +280,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # Ideally, read DDIM paper in-detail understanding
 
         # Notation (<variable name> -> <name in paper>
-        # - pred_noise_t -> e_theta(x_t, t)
-        # - pred_original_sample -> f_theta(x_t, t) or x_0
+        # - pred_noise_t -> e_theta(x_t, timestep)
+        # - pred_original_sample -> f_theta(x_t, timestep) or x_0
         # - std_dev_t -> sigma_t
         # - eta -> η
         # - pred_sample_direction -> "direction pointing to x_t"
         # - pred_prev_sample -> "x_t-1"
 
-        # 1. get previous step value (=t-1)
+        # 1. get previous step value (=timestep-1)
         prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
 
         # 2. compute alphas, betas
@@ -258,7 +298,21 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 
         # 3. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+        if self.prediction_type == "epsilon":
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+            eps = torch.tensor(1)
+        elif self.prediction_type == "sample":
+            pred_original_sample = model_output
+            eps = torch.tensor(1)
+        elif self.prediction_type == "velocity":
+            # v_t = alpha_t * epsilon - sigma_t * x
+            # need to merge the PRs for sigma to be available in DDPM
+            pred_original_sample = sample * self.alphas[timestep] - model_output * self.sigmas[timestep]
+            eps = model_output * self.alphas[timestep] + sample * self.sigmas[timestep]
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `velocity`"
+            )
 
         # 4. Clip "predicted x_0"
         if self.config.clip_sample:
@@ -266,7 +320,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 
         # 5. compute variance: "sigma_t(η)" -> see formula (16)
         # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
-        variance = self._get_variance(timestep, prev_timestep)
+        variance = self._get_variance(timestep, prev_timestep, eta)
         std_dev_t = eta * variance ** (0.5)
 
         if use_clipped_model_output:
@@ -274,10 +328,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
 
         # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
-
-        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
-        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+        if self.prediction_type == "epsilon":
+            pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
+
+            # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+            prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + eps * pred_sample_direction
+        else:
+            alpha_prev = self.alphas[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+            prev_sample = pred_original_sample * alpha_prev + eps * variance
 
         if eta > 0:
             # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
@@ -300,7 +358,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
-
             prev_sample = prev_sample + variance
 
         if not return_dict:
             return (prev_sample,)
@@ -312,6 +369,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.FloatTensor,
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
+        if self.variance_type == "v_diffusion":
+            alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device)
+            z_t = alpha * original_samples + sigma * noise
+            return z_t
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
@@ -331,3 +392,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 
     def __len__(self):
         return self.config.num_train_timesteps
+
+    def get_alpha_sigma(self, sample, timesteps, device):
+        alpha = expand_to_shape(self.alphas, timesteps, sample.shape, device)
+        sigma = expand_to_shape(self.sigmas, timesteps, sample.shape, device)
+        return alpha, sigma
```
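A usage sketch of the arguments this file gains (mirroring the instantiation in `train_butterflies.py` above; the exact API is still in flux in this PR):

```python
from diffusers import DDIMScheduler

# With prediction_type="velocity", __init__ builds the cosine/sine alpha/sigma
# tables via t_to_alpha_sigma, and step() reconstructs x0 and eps from the
# predicted v before forming the previous sample.
scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_schedule="squaredcos_cap_v2",
    variance_type="v_diffusion",
    prediction_type="velocity",
)
```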
src/diffusers/schedulers/scheduling_ddpm.py:

```diff
@@ -23,7 +23,7 @@ import torch
 
 from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
 from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import SchedulerMixin, expand_to_shape
 
 
 @dataclass
@@ -99,9 +99,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
         clip_sample (`bool`, default `True`):
             option to clip predicted sample between -1 and 1 for numerical stability.
-        predict_epsilon (`bool`):
-            optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
-
+        prediction_type (`str`, default `epsilon`, optional):
+            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+            process), `sample` (directly predicting the noisy sample`) or `velocity` (see section 2.4
+            https://imagen.research.google/video/paper.pdf)
+        predict_epsilon (`bool`, default `True`):
+            deprecated flag (removing v0.10.0) for epsilon vs. direct sample prediction.
     """
 
     _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
@@ -116,6 +119,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         trained_betas: Optional[np.ndarray] = None,
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
+        prediction_type: str = "epsilon",
         predict_epsilon: bool = True,
     ):
         if trained_betas is not None:
@@ -139,7 +143,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
-        self.one = torch.tensor(1.0)
+        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
+        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)
 
         # standard deviation of the initial noise distribution
         self.init_noise_sigma = 1.0
@@ -149,6 +154,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
 
         self.variance_type = variance_type
+        self.prediction_type = prediction_type
+
+        message = (
+            "Please make sure to instantiate your scheduler with `prediction_type=epsilon` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_config(<model_id>, prediction_type='epsilon')`."
+        )
+        deprecate("predict_epsilon", "0.10.0", message)
 
     def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
         """
@@ -179,14 +191,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         )[::-1].copy()
         self.timesteps = torch.from_numpy(timesteps).to(device)
 
-    def _get_variance(self, t, predicted_variance=None, variance_type=None):
-        alpha_prod_t = self.alphas_cumprod[t]
-        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+    def _get_variance(self, timestep, predicted_variance=None, variance_type=None):
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0)
 
-        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+        # For timestep > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
         # and sample from it to get previous sample
-        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
-        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
+        # x_{timestep-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
+        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep]
 
         if variance_type is None:
             variance_type = self.config.variance_type
@@ -199,17 +211,19 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             variance = torch.log(torch.clamp(variance, min=1e-20))
             variance = torch.exp(0.5 * variance)
         elif variance_type == "fixed_large":
-            variance = self.betas[t]
+            variance = self.betas[timestep]
         elif variance_type == "fixed_large_log":
             # Glide max_log
-            variance = torch.log(self.betas[t])
+            variance = torch.log(self.betas[timestep])
         elif variance_type == "learned":
             return predicted_variance
         elif variance_type == "learned_range":
             min_log = variance
-            max_log = self.betas[t]
+            max_log = self.betas[timestep]
             frac = (predicted_variance + 1) / 2
             variance = frac * max_log + (1 - frac) * min_log
+        elif variance_type == "v_diffusion":
+            variance = torch.log(self.betas[timestep] * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t))
 
         return variance
 
@@ -240,9 +254,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             returning a tuple, the first element is the sample tensor.
 
         """
+        if self.variance_type == "v_diffusion":
+            assert self.prediction_type == "velocity", "Need to use v prediction with v_diffusion"
         message = (
-            "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
-            " DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
+            "Please make sure to instantiate your scheduler with `prediction_type=epsilon` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_config(<model_id>, prediction_type=epsilon)`."
         )
         predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
         if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
@@ -250,34 +266,46 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             new_config["predict_epsilon"] = predict_epsilon
             self._internal_dict = FrozenDict(new_config)
 
-        t = timestep
-
         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
         else:
             predicted_variance = None
 
         # 1. compute alphas, betas
-        alpha_prod_t = self.alphas_cumprod[t]
-        alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
+        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0)
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
 
         # 2. compute predicted original sample from predicted noise also called
         # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
-        if self.config.predict_epsilon:
-            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
-        else:
+        if self.prediction_type == "velocity":
+            # x_recon in p_mean_variance
+            pred_original_sample = (
+                sample * self.sqrt_alphas_cumprod[timestep]
+                - model_output * self.sqrt_one_minus_alphas_cumprod[timestep]
+            )
+
+        # no check on predict_epsilon needed given the deprecation flag above
+        elif self.prediction_type == "sample" or not self.config.predict_epsilon:
             pred_original_sample = model_output
+
+        elif self.prediction_type == "epsilon" or self.config.predict_epsilon:
+            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+
+        else:
+            raise ValueError(
+                f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `velocity`"
+            )
 
         # 3. Clip "predicted x_0"
         if self.config.clip_sample:
             pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
 
         # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
-        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
-        current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t
+        current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t
 
         # 5. Compute predicted previous sample µ_t
         # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
@@ -285,7 +313,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         # 6. Add noise
         variance = 0
-        if t > 0:
+        if timestep > 0:
             device = model_output.device
             if device.type == "mps":
                 # randn does not work reproducibly on mps
@@ -296,9 +324,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
                 model_output.shape, generator=generator, device=device, dtype=model_output.dtype
             )
             if self.variance_type == "fixed_small_log":
-                variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
+                variance = self._get_variance(timestep, predicted_variance=predicted_variance) * variance_noise
+            elif self.variance_type == "v_diffusion":
+                variance = torch.exp(0.5 * self._get_variance(timestep, predicted_variance)) * variance_noise
             else:
-                variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
+                variance = (
+                    self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5
+                ) * variance_noise
 
         pred_prev_sample = pred_prev_sample + variance
 
@@ -313,6 +345,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.FloatTensor,
         timesteps: torch.IntTensor,
     ) -> torch.FloatTensor:
+        if self.variance_type == "v_diffusion":
+            alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device)
+            z_t = alpha * original_samples + sigma * noise
+            return z_t
+
         # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
         timesteps = timesteps.to(original_samples.device)
@@ -332,3 +369,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
     def __len__(self):
         return self.config.num_train_timesteps
+
+    def get_alpha_sigma(self, sample, timesteps, device):
+        alpha = expand_to_shape(self.sqrt_alphas_cumprod, timesteps, sample.shape, device)
+        sigma = expand_to_shape(self.sqrt_one_minus_alphas_cumprod, timesteps, sample.shape, device)
+        return alpha, sigma
```
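For DDPM, the alpha/sigma pair returned by the new `get_alpha_sigma` comes from the cumulative product of `1 - beta`, so the same `alpha² + sigma² = 1` identity holds here too. A small illustrative check (the linear beta schedule below is just an example, not taken from the diff):

```python
import torch

# alpha_t = sqrt(prod(1 - beta_i)), sigma_t = sqrt(1 - prod(1 - beta_i)),
# so alpha_t**2 + sigma_t**2 == 1 and add_noise() computes
# z_t = alpha_t * x0 + sigma_t * eps, matching the training script.
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1 - alphas_cumprod).sqrt()
assert torch.allclose(sqrt_alphas_cumprod**2 + sqrt_one_minus_alphas_cumprod**2, torch.ones(1000))
```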
src/diffusers/schedulers/scheduling_dpmsolver_multistep.py:

```diff
@@ -21,7 +21,7 @@ import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
+from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate
 from .scheduling_utils import SchedulerMixin, SchedulerOutput
 
 
@@ -88,9 +88,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
             sampling, and `solver_order=3` for unconditional sampling.
         predict_epsilon (`bool`, default `True`):
-            we currently support both the noise prediction model and the data prediction model. If the model predicts
-            the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set
-            `predict_epsilon` to `False`.
+            deprecated flag (removing v0.10.0); we currently support both the noise prediction model and the data
+            prediction model. If the model predicts the noise / epsilon, set `predict_epsilon` to `True`. If the model
+            predicts the data / x0 directly, set `predict_epsilon` to `False`.
+        prediction_type (`str`, default `epsilon`, optional):
+            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+            process), `sample` (directly predicting the noisy sample`) or `velocity` (see section 2.4
+            https://imagen.research.google/video/paper.pdf)
         thresholding (`bool`, default `False`):
             whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
             For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
@@ -128,6 +132,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         beta_schedule: str = "linear",
         trained_betas: Optional[np.ndarray] = None,
         solver_order: int = 2,
+        prediction_type: str = "epsilon",
         predict_epsilon: bool = True,
         thresholding: bool = False,
         dynamic_thresholding_ratio: float = 0.995,
@@ -174,6 +179,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         self.model_outputs = [None] * solver_order
         self.lower_order_nums = 0
 
+        if prediction_type not in ["epsilon", "sample"]:
+            raise ValueError(
+                f"Prediction type {self.config.prediction_type} not supported by DPMSolverMultistepScheduler"
+            )
+
+        message = (
+            "Please make sure to instantiate your scheduler with `prediction_type=epsilon` instead. E.g. `scheduler ="
+            " DDPMScheduler.from_config(<model_id>, prediction_type='epsilon')`."
+        )
+        deprecate("predict_epsilon", "0.10.0", message)
+
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -221,11 +237,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         """
         # DPM-Solver++ needs to solve an integral of the data prediction model.
         if self.config.algorithm_type == "dpmsolver++":
-            if self.config.predict_epsilon:
+            if self.config.prediction_type == "epsilon":
                 alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
                 x0_pred = (sample - sigma_t * model_output) / alpha_t
-            else:
+            elif self.config.prediction_type == "sample":
                 x0_pred = model_output
+            else:
+                raise ValueError(
+                    f"Prediction type {self.config.prediction_type} not supported by DPMSolverMultistepScheduler"
+                )
             if self.config.thresholding:
                 # Dynamic thresholding in https://arxiv.org/abs/2205.11487
                 dynamic_max_val = torch.quantile(
```
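Unlike the DDIM and DDPM schedulers, `DPMSolverMultistepScheduler` only accepts `epsilon` and `sample` here; per the guard added in `__init__`, anything else fails at construction time. An illustrative sketch (API as modified in this PR):

```python
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler(prediction_type="epsilon")  # accepted
try:
    DPMSolverMultistepScheduler(prediction_type="velocity")  # rejected by the new guard
except ValueError as err:
    print(err)
```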
src/diffusers/schedulers/scheduling_utils.py:

```diff
@@ -152,3 +152,14 @@ class SchedulerMixin:
             getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
         ]
         return compatible_classes
+
+
+def expand_to_shape(input, timesteps, shape, device):
+    """
+    Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast
+    nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once.
+    """
+    out = torch.gather(input.to(device), 0, timesteps.to(device))
+    reshape = [shape[0]] + [1] * (len(shape) - 1)
+    out = out.reshape(*reshape)
+    return out
```
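An illustrative use of `expand_to_shape`, assuming the helper defined above is in scope: gather one schedule entry per batch element and reshape so it broadcasts against image tensors (the schedule and shapes below are made up for the example):

```python
import torch

alphas = torch.linspace(1.0, 0.0, 1000)  # 1D schedule, one entry per timestep
timesteps = torch.tensor([0, 499, 999])  # one timestep per batch element
shape = (3, 3, 128, 128)                 # batch of 3 RGB images
out = expand_to_shape(alphas, timesteps, shape, "cpu")
print(out.shape)  # torch.Size([3, 1, 1, 1]) -> broadcasts over (3, 3, 128, 128)
```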
Scheduler tests:

```diff
@@ -599,9 +599,9 @@ class DDPMSchedulerTest(SchedulerCommonTest):
         for clip_sample in [True, False]:
             self.check_over_configs(clip_sample=clip_sample)
 
-    def test_predict_epsilon(self):
-        for predict_epsilon in [True, False]:
-            self.check_over_configs(predict_epsilon=predict_epsilon)
+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "sample", "velocity"]:
+            self.check_over_configs(prediction_type=prediction_type)
 
     def test_deprecated_epsilon(self):
         deprecate("remove this test", "0.10.0", "remove")
@@ -613,7 +613,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
         time_step = 4
 
         scheduler = scheduler_class(**scheduler_config)
-        scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config)
+        scheduler_eps = scheduler_class(prediction_type="sample", **scheduler_config)
 
         kwargs = {}
         if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
@@ -728,6 +728,10 @@ class DDIMSchedulerTest(SchedulerCommonTest):
         for schedule in ["linear", "squaredcos_cap_v2"]:
            self.check_over_configs(beta_schedule=schedule)
 
+    def test_prediction_type(self):
+        for prediction_type in ["epsilon", "sample", "velocity"]:
+            self.check_over_configs(prediction_type=prediction_type)
+
     def test_clip_sample(self):
         for clip_sample in [True, False]:
             self.check_over_configs(clip_sample=clip_sample)
```
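The same round-trip can also be checked outside the repo's `SchedulerCommonTest` harness with plain `pytest`; a standalone sketch, assuming the `prediction_type` argument added in this PR:

```python
import pytest
from diffusers import DDPMScheduler

# Every supported prediction_type should round-trip through the registered config.
@pytest.mark.parametrize("prediction_type", ["epsilon", "sample", "velocity"])
def test_prediction_type_roundtrip(prediction_type):
    scheduler = DDPMScheduler(prediction_type=prediction_type)
    assert scheduler.config.prediction_type == prediction_type
```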