mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-10 06:24:19 +08:00
Compare commits
20 Commits
sf-comfy-l
...
v_predicti
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
79ec3a8a39 | ||
|
|
da5e677c18 | ||
|
|
b70f6cd5e0 | ||
|
|
66951ec084 | ||
|
|
172b242c2a | ||
|
|
e701a97838 | ||
|
|
c1a0584213 | ||
|
|
3adf87b2d9 | ||
|
|
5a509dbedd | ||
|
|
e39198306b | ||
|
|
11362ae5d2 | ||
|
|
56164f56fb | ||
|
|
8fe2ff4b16 | ||
|
|
f00d896a1e | ||
|
|
ac6be90a71 | ||
|
|
4c6850473d | ||
|
|
3eb2593d9a | ||
|
|
7eb4bfae6c | ||
|
|
b7d0c1e84a | ||
|
|
798263f629 |
@@ -194,16 +194,28 @@ def parse_args():
|
|||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--predict_epsilon",
|
"--prediction_type",
|
||||||
action="store_true",
|
type=str,
|
||||||
default=True,
|
default="epsilon",
|
||||||
help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
|
help=(
|
||||||
|
"Whether the model should predict the 'epsilon'/noise error, directly the reconstructed image 'x0', or the"
|
||||||
|
" velocity of the ODE 'velocity'."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--ddpm_num_steps", type=int, default=1000)
|
parser.add_argument("--ddpm_num_steps", type=int, default=1000)
|
||||||
parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
|
parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
message = (
|
||||||
|
"Please make sure to instantiate your training with `--prediction_type=epsilon` instead. E.g. `scheduler ="
|
||||||
|
" DDPMScheduler.from_config(<model_id>, prediction_type=epsilon)`."
|
||||||
|
)
|
||||||
|
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=args)
|
||||||
|
if predict_epsilon:
|
||||||
|
args.prediction_type = "epsilon"
|
||||||
|
|
||||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||||
args.local_rank = env_local_rank
|
args.local_rank = env_local_rank
|
||||||
@@ -256,13 +268,13 @@ def main(args):
|
|||||||
"UpBlock2D",
|
"UpBlock2D",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
accepts_predict_epsilon = "predict_epsilon" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
|
accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
|
||||||
|
|
||||||
if accepts_predict_epsilon:
|
if accepts_prediction_type:
|
||||||
noise_scheduler = DDPMScheduler(
|
noise_scheduler = DDPMScheduler(
|
||||||
num_train_timesteps=args.ddpm_num_steps,
|
num_train_timesteps=args.ddpm_num_steps,
|
||||||
beta_schedule=args.ddpm_beta_schedule,
|
beta_schedule=args.ddpm_beta_schedule,
|
||||||
predict_epsilon=args.predict_epsilon,
|
prediction_type=args.prediction_type,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
|
noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
|
||||||
@@ -365,7 +377,7 @@ def main(args):
|
|||||||
# Predict the noise residual
|
# Predict the noise residual
|
||||||
model_output = model(noisy_images, timesteps).sample
|
model_output = model(noisy_images, timesteps).sample
|
||||||
|
|
||||||
if args.predict_epsilon:
|
if args.prediction_type == "epsilon":
|
||||||
loss = F.mse_loss(model_output, noise) # this could have different weights!
|
loss = F.mse_loss(model_output, noise) # this could have different weights!
|
||||||
else:
|
else:
|
||||||
alpha_t = _extract_into_tensor(
|
alpha_t = _extract_into_tensor(
|
||||||
|
|||||||
227
examples/v_prediction/train_butterflies.py
Normal file
227
examples/v_prediction/train_butterflies.py
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
import glob
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from accelerate import Accelerator
|
||||||
|
from datasets import load_dataset
|
||||||
|
from diffusers import DDIMPipeline, DDIMScheduler, DDPMPipeline, DDPMScheduler, UNet2DModel
|
||||||
|
from diffusers.hub_utils import init_git_repo, push_to_hub
|
||||||
|
from diffusers.optimization import get_cosine_schedule_with_warmup
|
||||||
|
from PIL import Image
|
||||||
|
from torchvision import transforms
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TrainingConfig:
|
||||||
|
image_size = 128 # the generated image resolution
|
||||||
|
train_batch_size = 16
|
||||||
|
eval_batch_size = 16 # how many images to sample during evaluation
|
||||||
|
num_epochs = 50
|
||||||
|
gradient_accumulation_steps = 1
|
||||||
|
learning_rate = 5e-5
|
||||||
|
lr_warmup_steps = 500
|
||||||
|
save_image_epochs = 10
|
||||||
|
save_model_epochs = 30
|
||||||
|
mixed_precision = "fp16" # `no` for float32, `fp16` for automatic mixed precision
|
||||||
|
output_dir = "ddim-butterflies-128-v-diffusion" # the model namy locally and on the HF Hub
|
||||||
|
|
||||||
|
push_to_hub = False # whether to upload the saved model to the HF Hub
|
||||||
|
hub_private_repo = False
|
||||||
|
overwrite_output_dir = True # overwrite the old model when re-running the notebook
|
||||||
|
seed = 0
|
||||||
|
|
||||||
|
|
||||||
|
config = TrainingConfig()
|
||||||
|
|
||||||
|
|
||||||
|
config.dataset_name = "huggan/smithsonian_butterflies_subset"
|
||||||
|
dataset = load_dataset(config.dataset_name, split="train")
|
||||||
|
|
||||||
|
|
||||||
|
preprocess = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.Resize((config.image_size, config.image_size)),
|
||||||
|
transforms.RandomHorizontalFlip(),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.5], [0.5]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def transform(examples):
|
||||||
|
images = [preprocess(image.convert("RGB")) for image in examples["image"]]
|
||||||
|
return {"images": images}
|
||||||
|
|
||||||
|
|
||||||
|
dataset.set_transform(transform)
|
||||||
|
|
||||||
|
|
||||||
|
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
|
||||||
|
|
||||||
|
|
||||||
|
model = UNet2DModel(
|
||||||
|
sample_size=config.image_size, # the target image resolution
|
||||||
|
in_channels=3, # the number of input channels, 3 for RGB images
|
||||||
|
out_channels=3, # the number of output channels
|
||||||
|
layers_per_block=2, # how many ResNet layers to use per UNet block
|
||||||
|
block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channes for each UNet block
|
||||||
|
down_block_types=(
|
||||||
|
"DownBlock2D", # a regular ResNet downsampling block
|
||||||
|
"DownBlock2D",
|
||||||
|
"DownBlock2D",
|
||||||
|
"DownBlock2D",
|
||||||
|
"AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
|
||||||
|
"DownBlock2D",
|
||||||
|
),
|
||||||
|
up_block_types=(
|
||||||
|
"UpBlock2D", # a regular ResNet upsampling block
|
||||||
|
"AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
|
||||||
|
"UpBlock2D",
|
||||||
|
"UpBlock2D",
|
||||||
|
"UpBlock2D",
|
||||||
|
"UpBlock2D",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if config.output_dir.startswith("ddpm"):
|
||||||
|
noise_scheduler = DDPMScheduler(
|
||||||
|
num_train_timesteps=1000,
|
||||||
|
beta_schedule="squaredcos_cap_v2",
|
||||||
|
variance_type="v_diffusion",
|
||||||
|
prediction_type="velocity",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
noise_scheduler = DDIMScheduler(
|
||||||
|
num_train_timesteps=1000,
|
||||||
|
beta_schedule="squaredcos_cap_v2",
|
||||||
|
variance_type="v_diffusion",
|
||||||
|
prediction_type="velocity",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
|
||||||
|
|
||||||
|
|
||||||
|
lr_scheduler = get_cosine_schedule_with_warmup(
|
||||||
|
optimizer=optimizer,
|
||||||
|
num_warmup_steps=config.lr_warmup_steps,
|
||||||
|
num_training_steps=(len(train_dataloader) * config.num_epochs),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_grid(images, rows, cols):
|
||||||
|
w, h = images[0].size
|
||||||
|
grid = Image.new("RGB", size=(cols * w, rows * h))
|
||||||
|
for i, image in enumerate(images):
|
||||||
|
grid.paste(image, box=(i % cols * w, i // cols * h))
|
||||||
|
return grid
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(config, epoch, pipeline):
|
||||||
|
# Sample some images from random noise (this is the backward diffusion process).
|
||||||
|
# The default pipeline output type is `List[PIL.Image]`
|
||||||
|
images = pipeline(
|
||||||
|
batch_size=config.eval_batch_size,
|
||||||
|
generator=torch.manual_seed(config.seed),
|
||||||
|
).images
|
||||||
|
|
||||||
|
# Make a grid out of the images
|
||||||
|
image_grid = make_grid(images, rows=4, cols=4)
|
||||||
|
|
||||||
|
# Save the images
|
||||||
|
test_dir = os.path.join(config.output_dir, "samples")
|
||||||
|
os.makedirs(test_dir, exist_ok=True)
|
||||||
|
image_grid.save(f"{test_dir}/{epoch:04d}.png")
|
||||||
|
|
||||||
|
|
||||||
|
def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
|
||||||
|
# Initialize accelerator and tensorboard logging
|
||||||
|
accelerator = Accelerator(
|
||||||
|
mixed_precision=config.mixed_precision,
|
||||||
|
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
||||||
|
log_with="tensorboard",
|
||||||
|
logging_dir=os.path.join(config.output_dir, "logs"),
|
||||||
|
)
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
if config.push_to_hub:
|
||||||
|
repo = init_git_repo(config, at_init=True)
|
||||||
|
accelerator.init_trackers("train_example")
|
||||||
|
|
||||||
|
# Prepare everything
|
||||||
|
# There is no specific order to remember, you just need to unpack the
|
||||||
|
# objects in the same order you gave them to the prepare method.
|
||||||
|
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
|
model, optimizer, train_dataloader, lr_scheduler
|
||||||
|
)
|
||||||
|
|
||||||
|
global_step = 0
|
||||||
|
|
||||||
|
if config.output_dir.startswith("ddpm"):
|
||||||
|
pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
|
||||||
|
else:
|
||||||
|
pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
|
||||||
|
|
||||||
|
evaluate(config, 0, pipeline)
|
||||||
|
|
||||||
|
# Now you train the model
|
||||||
|
for epoch in range(config.num_epochs):
|
||||||
|
progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
|
||||||
|
progress_bar.set_description(f"Epoch {epoch}")
|
||||||
|
|
||||||
|
for step, batch in enumerate(train_dataloader):
|
||||||
|
clean_images = batch["images"]
|
||||||
|
# Sample noise to add to the images
|
||||||
|
noise = torch.randn(clean_images.shape).to(clean_images.device)
|
||||||
|
bs = clean_images.shape[0]
|
||||||
|
|
||||||
|
# Sample a random timestep for each image
|
||||||
|
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()
|
||||||
|
|
||||||
|
with accelerator.accumulate(model):
|
||||||
|
# Predict the noise residual
|
||||||
|
alpha_t, sigma_t = noise_scheduler.get_alpha_sigma(clean_images, timesteps, accelerator.device)
|
||||||
|
z_t = alpha_t * clean_images + sigma_t * noise
|
||||||
|
noise_pred = model(z_t, timesteps).sample
|
||||||
|
v = alpha_t * noise - sigma_t * clean_images
|
||||||
|
loss = F.mse_loss(noise_pred, v)
|
||||||
|
accelerator.backward(loss)
|
||||||
|
|
||||||
|
accelerator.clip_grad_norm_(model.parameters(), 1.0)
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
progress_bar.update(1)
|
||||||
|
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
|
||||||
|
progress_bar.set_postfix(**logs)
|
||||||
|
accelerator.log(logs, step=global_step)
|
||||||
|
global_step += 1
|
||||||
|
|
||||||
|
# After each epoch you optionally sample some demo images with evaluate() and save the model
|
||||||
|
if accelerator.is_main_process:
|
||||||
|
if config.output_dir.startswith("ddpm"):
|
||||||
|
pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
|
||||||
|
else:
|
||||||
|
pipeline = DDIMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
|
||||||
|
|
||||||
|
if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
|
||||||
|
evaluate(config, epoch, pipeline)
|
||||||
|
|
||||||
|
if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
|
||||||
|
if config.push_to_hub:
|
||||||
|
push_to_hub(config, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=True)
|
||||||
|
else:
|
||||||
|
pipeline.save_pretrained(config.output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
|
||||||
|
|
||||||
|
train_loop(*args)
|
||||||
|
|
||||||
|
sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
|
||||||
|
Image.open(sample_images[-1])
|
||||||
@@ -24,7 +24,7 @@ import torch
|
|||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
|
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
|
||||||
from .scheduling_utils import SchedulerMixin
|
from .scheduling_utils import SchedulerMixin, expand_to_shape
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -75,6 +75,18 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor
|
|||||||
return torch.tensor(betas)
|
return torch.tensor(betas)
|
||||||
|
|
||||||
|
|
||||||
|
def t_to_alpha_sigma(num_diffusion_timesteps):
|
||||||
|
"""Returns the scaling factors for the clean image and for the noise, given
|
||||||
|
a timestep."""
|
||||||
|
alphas = torch.cos(
|
||||||
|
torch.tensor([(t / num_diffusion_timesteps) * math.pi / 2 for t in range(num_diffusion_timesteps)])
|
||||||
|
)
|
||||||
|
sigmas = torch.sin(
|
||||||
|
torch.tensor([(t / num_diffusion_timesteps) * math.pi / 2 for t in range(num_diffusion_timesteps)])
|
||||||
|
)
|
||||||
|
return alphas, sigmas
|
||||||
|
|
||||||
|
|
||||||
class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
"""
|
"""
|
||||||
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
|
Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
|
||||||
@@ -106,6 +118,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
an offset added to the inference steps. You can use a combination of `offset=1` and
|
an offset added to the inference steps. You can use a combination of `offset=1` and
|
||||||
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
`set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
|
||||||
stable diffusion.
|
stable diffusion.
|
||||||
|
prediction_type (`str`, default `epsilon`, optional):
|
||||||
|
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||||
|
process), `sample` (directly predicting the noisy sample`) or `velocity` (see section 2.4
|
||||||
|
https://imagen.research.google/video/paper.pdf)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -121,7 +137,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
trained_betas: Optional[np.ndarray] = None,
|
trained_betas: Optional[np.ndarray] = None,
|
||||||
clip_sample: bool = True,
|
clip_sample: bool = True,
|
||||||
set_alpha_to_one: bool = True,
|
set_alpha_to_one: bool = True,
|
||||||
|
variance_type: str = "fixed",
|
||||||
steps_offset: int = 0,
|
steps_offset: int = 0,
|
||||||
|
prediction_type: str = "epsilon",
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
if trained_betas is not None:
|
if trained_betas is not None:
|
||||||
self.betas = torch.from_numpy(trained_betas)
|
self.betas = torch.from_numpy(trained_betas)
|
||||||
@@ -138,14 +157,22 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||||
|
|
||||||
|
self.variance_type = variance_type
|
||||||
self.alphas = 1.0 - self.betas
|
self.alphas = 1.0 - self.betas
|
||||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||||
|
if prediction_type == "velocity":
|
||||||
|
self.alphas, self.sigmas = t_to_alpha_sigma(num_train_timesteps)
|
||||||
|
|
||||||
# At every step in ddim, we are looking into the previous alphas_cumprod
|
# At every step in ddim, we are looking into the previous alphas_cumprod
|
||||||
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
# For the final step, there is no previous alphas_cumprod because we are already at 0
|
||||||
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
# `set_alpha_to_one` decides whether we set this parameter simply to one or
|
||||||
# whether we use the final alpha of the "non-previous" one.
|
# whether we use the final alpha of the "non-previous" one.
|
||||||
self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
|
if set_alpha_to_one:
|
||||||
|
self.final_alpha_cumprod = torch.tensor(1.0)
|
||||||
|
self.final_sigma = torch.tensor(0.0) # TODO rename set_alpha_to_one for something general with sigma=0
|
||||||
|
else:
|
||||||
|
self.final_alpha_cumprod = self.alphas_cumprod[0]
|
||||||
|
self.final_sigma = self.sigmas[0] if prediction_type == "velocity" else None
|
||||||
|
|
||||||
# standard deviation of the initial noise distribution
|
# standard deviation of the initial noise distribution
|
||||||
self.init_noise_sigma = 1.0
|
self.init_noise_sigma = 1.0
|
||||||
@@ -153,6 +180,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
# setable values
|
# setable values
|
||||||
self.num_inference_steps = None
|
self.num_inference_steps = None
|
||||||
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
||||||
|
self.variance_type = variance_type
|
||||||
|
self.prediction_type = prediction_type
|
||||||
|
|
||||||
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
||||||
"""
|
"""
|
||||||
@@ -162,20 +191,31 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
Args:
|
Args:
|
||||||
sample (`torch.FloatTensor`): input sample
|
sample (`torch.FloatTensor`): input sample
|
||||||
timestep (`int`, optional): current timestep
|
timestep (`int`, optional): current timestep
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`torch.FloatTensor`: scaled input sample
|
`torch.FloatTensor`: scaled input sample
|
||||||
"""
|
"""
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
def _get_variance(self, timestep, prev_timestep):
|
def _get_variance(self, timestep, prev_timestep, eta=0):
|
||||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||||
beta_prod_t = 1 - alpha_prod_t
|
beta_prod_t = 1 - alpha_prod_t
|
||||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||||
|
|
||||||
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
if self.variance_type == "fixed":
|
||||||
|
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
||||||
|
elif self.variance_type == "v_diffusion":
|
||||||
|
# If eta > 0, adjust the scaling factor for the predicted noise
|
||||||
|
# downward according to the amount of additional noise to add
|
||||||
|
alpha_prev = self.alphas[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||||
|
sigma_prev = self.sigmas[prev_timestep] if prev_timestep >= 0 else self.final_sigma
|
||||||
|
if eta:
|
||||||
|
numerator = eta * (sigma_prev**2 / self.sigmas[timestep] ** 2).clamp(min=1.0e-7).sqrt()
|
||||||
|
else:
|
||||||
|
numerator = 0
|
||||||
|
denominator = (1 - self.alphas[timestep] ** 2 / alpha_prev**2).clamp(min=1.0e-7).sqrt()
|
||||||
|
ddim_sigma = (numerator * denominator).clamp(min=1.0e-7)
|
||||||
|
variance = (sigma_prev**2 - ddim_sigma**2).clamp(min=1.0e-7).sqrt()
|
||||||
return variance
|
return variance
|
||||||
|
|
||||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||||
@@ -240,14 +280,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
# Ideally, read DDIM paper in-detail understanding
|
# Ideally, read DDIM paper in-detail understanding
|
||||||
|
|
||||||
# Notation (<variable name> -> <name in paper>
|
# Notation (<variable name> -> <name in paper>
|
||||||
# - pred_noise_t -> e_theta(x_t, t)
|
# - pred_noise_t -> e_theta(x_t, timestep)
|
||||||
# - pred_original_sample -> f_theta(x_t, t) or x_0
|
# - pred_original_sample -> f_theta(x_t, timestep) or x_0
|
||||||
# - std_dev_t -> sigma_t
|
# - std_dev_t -> sigma_t
|
||||||
# - eta -> η
|
# - eta -> η
|
||||||
# - pred_sample_direction -> "direction pointing to x_t"
|
# - pred_sample_direction -> "direction pointing to x_t"
|
||||||
# - pred_prev_sample -> "x_t-1"
|
# - pred_prev_sample -> "x_t-1"
|
||||||
|
|
||||||
# 1. get previous step value (=t-1)
|
# 1. get previous step value (=timestep-1)
|
||||||
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
|
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
|
||||||
|
|
||||||
# 2. compute alphas, betas
|
# 2. compute alphas, betas
|
||||||
@@ -258,7 +298,21 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# 3. compute predicted original sample from predicted noise also called
|
# 3. compute predicted original sample from predicted noise also called
|
||||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
if self.prediction_type == "epsilon":
|
||||||
|
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||||
|
eps = torch.tensor(1)
|
||||||
|
elif self.prediction_type == "sample":
|
||||||
|
pred_original_sample = model_output
|
||||||
|
eps = torch.tensor(1)
|
||||||
|
elif self.prediction_type == "velocity":
|
||||||
|
# v_t = alpha_t * epsilon - sigma_t * x
|
||||||
|
# need to merge the PRs for sigma to be available in DDPM
|
||||||
|
pred_original_sample = sample * self.alphas[timestep] - model_output * self.sigmas[timestep]
|
||||||
|
eps = model_output * self.alphas[timestep] + sample * self.sigmas[timestep]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `velocity`"
|
||||||
|
)
|
||||||
|
|
||||||
# 4. Clip "predicted x_0"
|
# 4. Clip "predicted x_0"
|
||||||
if self.config.clip_sample:
|
if self.config.clip_sample:
|
||||||
@@ -266,7 +320,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
||||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||||
variance = self._get_variance(timestep, prev_timestep)
|
variance = self._get_variance(timestep, prev_timestep, eta)
|
||||||
std_dev_t = eta * variance ** (0.5)
|
std_dev_t = eta * variance ** (0.5)
|
||||||
|
|
||||||
if use_clipped_model_output:
|
if use_clipped_model_output:
|
||||||
@@ -274,10 +328,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
|
model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
|
||||||
|
|
||||||
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||||
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
|
if self.prediction_type == "epsilon":
|
||||||
|
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
|
||||||
|
|
||||||
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||||
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + eps * pred_sample_direction
|
||||||
|
else:
|
||||||
|
alpha_prev = self.alphas[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||||
|
prev_sample = pred_original_sample * alpha_prev + eps * variance
|
||||||
|
|
||||||
if eta > 0:
|
if eta > 0:
|
||||||
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
|
# randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
|
||||||
@@ -300,7 +358,6 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
|
variance = self._get_variance(timestep, prev_timestep) ** (0.5) * eta * variance_noise
|
||||||
|
|
||||||
prev_sample = prev_sample + variance
|
prev_sample = prev_sample + variance
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (prev_sample,)
|
return (prev_sample,)
|
||||||
|
|
||||||
@@ -312,6 +369,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
noise: torch.FloatTensor,
|
noise: torch.FloatTensor,
|
||||||
timesteps: torch.IntTensor,
|
timesteps: torch.IntTensor,
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
|
if self.variance_type == "v_diffusion":
|
||||||
|
alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device)
|
||||||
|
z_t = alpha * original_samples + sigma * noise
|
||||||
|
return z_t
|
||||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||||
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||||
timesteps = timesteps.to(original_samples.device)
|
timesteps = timesteps.to(original_samples.device)
|
||||||
@@ -331,3 +392,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.config.num_train_timesteps
|
return self.config.num_train_timesteps
|
||||||
|
|
||||||
|
def get_alpha_sigma(self, sample, timesteps, device):
|
||||||
|
alpha = expand_to_shape(self.alphas, timesteps, sample.shape, device)
|
||||||
|
sigma = expand_to_shape(self.sigmas, timesteps, sample.shape, device)
|
||||||
|
return alpha, sigma
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import torch
|
|||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
|
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
|
||||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
|
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate
|
||||||
from .scheduling_utils import SchedulerMixin
|
from .scheduling_utils import SchedulerMixin, expand_to_shape
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -99,9 +99,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
||||||
clip_sample (`bool`, default `True`):
|
clip_sample (`bool`, default `True`):
|
||||||
option to clip predicted sample between -1 and 1 for numerical stability.
|
option to clip predicted sample between -1 and 1 for numerical stability.
|
||||||
predict_epsilon (`bool`):
|
prediction_type (`str`, default `epsilon`, optional):
|
||||||
optional flag to use when the model predicts the noise (epsilon), or the samples instead of the noise.
|
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||||
|
process), `sample` (directly predicting the noisy sample`) or `velocity` (see section 2.4
|
||||||
|
https://imagen.research.google/video/paper.pdf)
|
||||||
|
predict_epsilon (`bool`, default `True`):
|
||||||
|
deprecated flag (removing v0.10.0) for epsilon vs. direct sample prediction.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
|
||||||
@@ -116,6 +119,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
trained_betas: Optional[np.ndarray] = None,
|
trained_betas: Optional[np.ndarray] = None,
|
||||||
variance_type: str = "fixed_small",
|
variance_type: str = "fixed_small",
|
||||||
clip_sample: bool = True,
|
clip_sample: bool = True,
|
||||||
|
prediction_type: str = "epsilon",
|
||||||
predict_epsilon: bool = True,
|
predict_epsilon: bool = True,
|
||||||
):
|
):
|
||||||
if trained_betas is not None:
|
if trained_betas is not None:
|
||||||
@@ -139,7 +143,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
self.alphas = 1.0 - self.betas
|
self.alphas = 1.0 - self.betas
|
||||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||||
self.one = torch.tensor(1.0)
|
self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
|
||||||
|
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)
|
||||||
|
|
||||||
# standard deviation of the initial noise distribution
|
# standard deviation of the initial noise distribution
|
||||||
self.init_noise_sigma = 1.0
|
self.init_noise_sigma = 1.0
|
||||||
@@ -149,6 +154,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
|
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
|
||||||
|
|
||||||
self.variance_type = variance_type
|
self.variance_type = variance_type
|
||||||
|
self.prediction_type = prediction_type
|
||||||
|
|
||||||
|
message = (
|
||||||
|
"Please make sure to instantiate your scheduler with `prediction_type=epsilon` instead. E.g. `scheduler ="
|
||||||
|
" DDPMScheduler.from_config(<model_id>, prediction_type='epsilon')`."
|
||||||
|
)
|
||||||
|
deprecate("predict_epsilon", "0.10.0", message)
|
||||||
|
|
||||||
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
||||||
"""
|
"""
|
||||||
@@ -179,14 +191,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
)[::-1].copy()
|
)[::-1].copy()
|
||||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||||
|
|
||||||
def _get_variance(self, t, predicted_variance=None, variance_type=None):
|
def _get_variance(self, timestep, predicted_variance=None, variance_type=None):
|
||||||
alpha_prod_t = self.alphas_cumprod[t]
|
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||||
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
|
alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0)
|
||||||
|
|
||||||
# For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
|
# For timestep > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
|
||||||
# and sample from it to get previous sample
|
# and sample from it to get previous sample
|
||||||
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
|
# x_{timestep-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
|
||||||
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
|
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep]
|
||||||
|
|
||||||
if variance_type is None:
|
if variance_type is None:
|
||||||
variance_type = self.config.variance_type
|
variance_type = self.config.variance_type
|
||||||
@@ -199,17 +211,19 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
variance = torch.log(torch.clamp(variance, min=1e-20))
|
variance = torch.log(torch.clamp(variance, min=1e-20))
|
||||||
variance = torch.exp(0.5 * variance)
|
variance = torch.exp(0.5 * variance)
|
||||||
elif variance_type == "fixed_large":
|
elif variance_type == "fixed_large":
|
||||||
variance = self.betas[t]
|
variance = self.betas[timestep]
|
||||||
elif variance_type == "fixed_large_log":
|
elif variance_type == "fixed_large_log":
|
||||||
# Glide max_log
|
# Glide max_log
|
||||||
variance = torch.log(self.betas[t])
|
variance = torch.log(self.betas[timestep])
|
||||||
elif variance_type == "learned":
|
elif variance_type == "learned":
|
||||||
return predicted_variance
|
return predicted_variance
|
||||||
elif variance_type == "learned_range":
|
elif variance_type == "learned_range":
|
||||||
min_log = variance
|
min_log = variance
|
||||||
max_log = self.betas[t]
|
max_log = self.betas[timestep]
|
||||||
frac = (predicted_variance + 1) / 2
|
frac = (predicted_variance + 1) / 2
|
||||||
variance = frac * max_log + (1 - frac) * min_log
|
variance = frac * max_log + (1 - frac) * min_log
|
||||||
|
elif variance_type == "v_diffusion":
|
||||||
|
variance = torch.log(self.betas[timestep] * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t))
|
||||||
|
|
||||||
return variance
|
return variance
|
||||||
|
|
||||||
@@ -240,9 +254,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
returning a tuple, the first element is the sample tensor.
|
returning a tuple, the first element is the sample tensor.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if self.variance_type == "v_diffusion":
|
||||||
|
assert self.prediction_type == "velocity", "Need to use v prediction with v_diffusion"
|
||||||
message = (
|
message = (
|
||||||
"Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler ="
|
"Please make sure to instantiate your scheduler with `prediction_type=epsilon` instead. E.g. `scheduler ="
|
||||||
" DDPMScheduler.from_pretrained(<model_id>, predict_epsilon=True)`."
|
" DDPMScheduler.from_config(<model_id>, prediction_type=epsilon)`."
|
||||||
)
|
)
|
||||||
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs)
|
||||||
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
|
if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon:
|
||||||
@@ -250,34 +266,46 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
new_config["predict_epsilon"] = predict_epsilon
|
new_config["predict_epsilon"] = predict_epsilon
|
||||||
self._internal_dict = FrozenDict(new_config)
|
self._internal_dict = FrozenDict(new_config)
|
||||||
|
|
||||||
t = timestep
|
|
||||||
|
|
||||||
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
|
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
|
||||||
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
|
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
|
||||||
else:
|
else:
|
||||||
predicted_variance = None
|
predicted_variance = None
|
||||||
|
|
||||||
# 1. compute alphas, betas
|
# 1. compute alphas, betas
|
||||||
alpha_prod_t = self.alphas_cumprod[t]
|
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||||
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
|
alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else torch.tensor(1.0)
|
||||||
beta_prod_t = 1 - alpha_prod_t
|
beta_prod_t = 1 - alpha_prod_t
|
||||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||||
|
|
||||||
# 2. compute predicted original sample from predicted noise also called
|
# 2. compute predicted original sample from predicted noise also called
|
||||||
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
|
||||||
if self.config.predict_epsilon:
|
if self.prediction_type == "velocity":
|
||||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
# x_recon in p_mean_variance
|
||||||
else:
|
pred_original_sample = (
|
||||||
|
sample * self.sqrt_alphas_cumprod[timestep]
|
||||||
|
- model_output * self.sqrt_one_minus_alphas_cumprod[timestep]
|
||||||
|
)
|
||||||
|
|
||||||
|
# not check on predict_epsilon for depreciation flag above
|
||||||
|
elif self.prediction_type == "sample" or not self.config.predict_epsilon:
|
||||||
pred_original_sample = model_output
|
pred_original_sample = model_output
|
||||||
|
|
||||||
|
elif self.prediction_type == "epsilon" or self.config.predict_epsilon:
|
||||||
|
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `velocity`"
|
||||||
|
)
|
||||||
|
|
||||||
# 3. Clip "predicted x_0"
|
# 3. Clip "predicted x_0"
|
||||||
if self.config.clip_sample:
|
if self.config.clip_sample:
|
||||||
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
|
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
|
||||||
|
|
||||||
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
|
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
|
||||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
||||||
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[t]) / beta_prod_t
|
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t
|
||||||
current_sample_coeff = self.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
|
current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t
|
||||||
|
|
||||||
# 5. Compute predicted previous sample µ_t
|
# 5. Compute predicted previous sample µ_t
|
||||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
||||||
@@ -285,7 +313,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# 6. Add noise
|
# 6. Add noise
|
||||||
variance = 0
|
variance = 0
|
||||||
if t > 0:
|
if timestep > 0:
|
||||||
device = model_output.device
|
device = model_output.device
|
||||||
if device.type == "mps":
|
if device.type == "mps":
|
||||||
# randn does not work reproducibly on mps
|
# randn does not work reproducibly on mps
|
||||||
@@ -296,9 +324,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
|
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
|
||||||
)
|
)
|
||||||
if self.variance_type == "fixed_small_log":
|
if self.variance_type == "fixed_small_log":
|
||||||
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
|
variance = self._get_variance(timestep, predicted_variance=predicted_variance) * variance_noise
|
||||||
|
elif self.variance_type == "v_diffusion":
|
||||||
|
variance = torch.exp(0.5 * self._get_variance(timestep, predicted_variance)) * variance_noise
|
||||||
else:
|
else:
|
||||||
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
|
variance = (
|
||||||
|
self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5
|
||||||
|
) * variance_noise
|
||||||
|
|
||||||
pred_prev_sample = pred_prev_sample + variance
|
pred_prev_sample = pred_prev_sample + variance
|
||||||
|
|
||||||
@@ -313,6 +345,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
noise: torch.FloatTensor,
|
noise: torch.FloatTensor,
|
||||||
timesteps: torch.IntTensor,
|
timesteps: torch.IntTensor,
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
|
if self.variance_type == "v_diffusion":
|
||||||
|
alpha, sigma = self.get_alpha_sigma(original_samples, timesteps, original_samples.device)
|
||||||
|
z_t = alpha * original_samples + sigma * noise
|
||||||
|
return z_t
|
||||||
|
|
||||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||||
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||||
timesteps = timesteps.to(original_samples.device)
|
timesteps = timesteps.to(original_samples.device)
|
||||||
@@ -332,3 +369,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.config.num_train_timesteps
|
return self.config.num_train_timesteps
|
||||||
|
|
||||||
|
def get_alpha_sigma(self, sample, timesteps, device):
|
||||||
|
alpha = expand_to_shape(self.sqrt_alphas_cumprod, timesteps, sample.shape, device)
|
||||||
|
sigma = expand_to_shape(self.sqrt_one_minus_alphas_cumprod, timesteps, sample.shape, device)
|
||||||
|
return alpha, sigma
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin, register_to_config
|
from ..configuration_utils import ConfigMixin, register_to_config
|
||||||
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
|
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate
|
||||||
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
from .scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -88,9 +88,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
|
the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided
|
||||||
sampling, and `solver_order=3` for unconditional sampling.
|
sampling, and `solver_order=3` for unconditional sampling.
|
||||||
predict_epsilon (`bool`, default `True`):
|
predict_epsilon (`bool`, default `True`):
|
||||||
we currently support both the noise prediction model and the data prediction model. If the model predicts
|
deprecated flag (removing v0.10.0); we currently support both the noise prediction model and the data
|
||||||
the noise / epsilon, set `predict_epsilon` to `True`. If the model predicts the data / x0 directly, set
|
prediction model. If the model predicts the noise / epsilon, set `predict_epsilon` to `True`. If the model
|
||||||
`predict_epsilon` to `False`.
|
predicts the data / x0 directly, set `predict_epsilon` to `False`.
|
||||||
|
prediction_type (`str`, default `epsilon`, optional):
|
||||||
|
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||||
|
process), `sample` (directly predicting the noisy sample`) or `velocity` (see section 2.4
|
||||||
|
https://imagen.research.google/video/paper.pdf)
|
||||||
thresholding (`bool`, default `False`):
|
thresholding (`bool`, default `False`):
|
||||||
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
|
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
|
||||||
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
|
For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to
|
||||||
@@ -128,6 +132,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
beta_schedule: str = "linear",
|
beta_schedule: str = "linear",
|
||||||
trained_betas: Optional[np.ndarray] = None,
|
trained_betas: Optional[np.ndarray] = None,
|
||||||
solver_order: int = 2,
|
solver_order: int = 2,
|
||||||
|
prediction_type: str = "epsilon",
|
||||||
predict_epsilon: bool = True,
|
predict_epsilon: bool = True,
|
||||||
thresholding: bool = False,
|
thresholding: bool = False,
|
||||||
dynamic_thresholding_ratio: float = 0.995,
|
dynamic_thresholding_ratio: float = 0.995,
|
||||||
@@ -174,6 +179,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.model_outputs = [None] * solver_order
|
self.model_outputs = [None] * solver_order
|
||||||
self.lower_order_nums = 0
|
self.lower_order_nums = 0
|
||||||
|
|
||||||
|
if prediction_type not in ["epsilon", "sample"]:
|
||||||
|
raise ValueError(
|
||||||
|
f"Prediction type {self.config.prediction_type} not supported by DPMSolverMultistepScheduler"
|
||||||
|
)
|
||||||
|
|
||||||
|
message = (
|
||||||
|
"Please make sure to instantiate your scheduler with `prediction_type=epsilon` instead. E.g. `scheduler ="
|
||||||
|
" DDPMScheduler.from_config(<model_id>, prediction_type='epsilon')`."
|
||||||
|
)
|
||||||
|
deprecate("predict_epsilon", "0.10.0", message)
|
||||||
|
|
||||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||||
"""
|
"""
|
||||||
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
|
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||||
@@ -221,11 +237,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
"""
|
"""
|
||||||
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
# DPM-Solver++ needs to solve an integral of the data prediction model.
|
||||||
if self.config.algorithm_type == "dpmsolver++":
|
if self.config.algorithm_type == "dpmsolver++":
|
||||||
if self.config.predict_epsilon:
|
if self.config.prediction_type == "epsilon":
|
||||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
||||||
else:
|
elif self.config.prediction_type == "sample":
|
||||||
x0_pred = model_output
|
x0_pred = model_output
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Prediction type {self.config.prediction_type} not supported by DPMSolverMultistepScheduler"
|
||||||
|
)
|
||||||
if self.config.thresholding:
|
if self.config.thresholding:
|
||||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||||
dynamic_max_val = torch.quantile(
|
dynamic_max_val = torch.quantile(
|
||||||
|
|||||||
@@ -152,3 +152,14 @@ class SchedulerMixin:
|
|||||||
getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
|
getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
|
||||||
]
|
]
|
||||||
return compatible_classes
|
return compatible_classes
|
||||||
|
|
||||||
|
|
||||||
|
def expand_to_shape(input, timesteps, shape, device):
|
||||||
|
"""
|
||||||
|
Helper indexes a 1D tensor `input` using a 1D index tensor `timesteps`, then reshapes the result to broadcast
|
||||||
|
nicely with `shape`. Useful for parallelizing operations over `shape[0]` number of diffusion steps at once.
|
||||||
|
"""
|
||||||
|
out = torch.gather(input.to(device), 0, timesteps.to(device))
|
||||||
|
reshape = [shape[0]] + [1] * (len(shape) - 1)
|
||||||
|
out = out.reshape(*reshape)
|
||||||
|
return out
|
||||||
|
|||||||
@@ -599,9 +599,9 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
|||||||
for clip_sample in [True, False]:
|
for clip_sample in [True, False]:
|
||||||
self.check_over_configs(clip_sample=clip_sample)
|
self.check_over_configs(clip_sample=clip_sample)
|
||||||
|
|
||||||
def test_predict_epsilon(self):
|
def test_prediction_type(self):
|
||||||
for predict_epsilon in [True, False]:
|
for prediction_type in ["epsilon", "sample", "velocity"]:
|
||||||
self.check_over_configs(predict_epsilon=predict_epsilon)
|
self.check_over_configs(prediction_type=prediction_type)
|
||||||
|
|
||||||
def test_deprecated_epsilon(self):
|
def test_deprecated_epsilon(self):
|
||||||
deprecate("remove this test", "0.10.0", "remove")
|
deprecate("remove this test", "0.10.0", "remove")
|
||||||
@@ -613,7 +613,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
|||||||
time_step = 4
|
time_step = 4
|
||||||
|
|
||||||
scheduler = scheduler_class(**scheduler_config)
|
scheduler = scheduler_class(**scheduler_config)
|
||||||
scheduler_eps = scheduler_class(predict_epsilon=False, **scheduler_config)
|
scheduler_eps = scheduler_class(prediction_type="sample", **scheduler_config)
|
||||||
|
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
|
||||||
@@ -728,6 +728,10 @@ class DDIMSchedulerTest(SchedulerCommonTest):
|
|||||||
for schedule in ["linear", "squaredcos_cap_v2"]:
|
for schedule in ["linear", "squaredcos_cap_v2"]:
|
||||||
self.check_over_configs(beta_schedule=schedule)
|
self.check_over_configs(beta_schedule=schedule)
|
||||||
|
|
||||||
|
def test_prediction_type(self):
|
||||||
|
for prediction_type in ["epsilon", "sample", "velocity"]:
|
||||||
|
self.check_over_configs(prediction_type=prediction_type)
|
||||||
|
|
||||||
def test_clip_sample(self):
|
def test_clip_sample(self):
|
||||||
for clip_sample in [True, False]:
|
for clip_sample in [True, False]:
|
||||||
self.check_over_configs(clip_sample=clip_sample)
|
self.check_over_configs(clip_sample=clip_sample)
|
||||||
|
|||||||
Reference in New Issue
Block a user