mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-17 09:54:41 +08:00
Compare commits
9 Commits
debug
...
dpm-mstep-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f238e0d862 | ||
|
|
a85a18c0ce | ||
|
|
515c105040 | ||
|
|
9eeb5e9a1a | ||
|
|
c95b545113 | ||
|
|
670c782cb2 | ||
|
|
a05a13a9ab | ||
|
|
3b886af21b | ||
|
|
8f78025482 |
@@ -181,9 +181,6 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
|
||||||
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
|
self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
|
||||||
|
|
||||||
# standard deviation of the initial noise distribution
|
|
||||||
self.init_noise_sigma = 1.0
|
|
||||||
|
|
||||||
# settings for DPM-Solver
|
# settings for DPM-Solver
|
||||||
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
|
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
|
||||||
if algorithm_type == "deis":
|
if algorithm_type == "deis":
|
||||||
@@ -200,9 +197,26 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
# setable values
|
# setable values
|
||||||
self.num_inference_steps = None
|
self.num_inference_steps = None
|
||||||
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
||||||
|
|
||||||
self.timesteps = torch.from_numpy(timesteps)
|
self.timesteps = torch.from_numpy(timesteps)
|
||||||
self.model_outputs = [None] * solver_order
|
self.model_outputs = [None] * solver_order
|
||||||
self.lower_order_nums = 0
|
self.lower_order_nums = 0
|
||||||
|
self._step_index = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def init_noise_sigma(self):
|
||||||
|
# standard deviation of the initial noise distribution
|
||||||
|
if self.config.timestep_spacing in ["linspace", "trailing"]:
|
||||||
|
return self.sigmas.max()
|
||||||
|
|
||||||
|
return (self.sigmas.max() ** 2 + 1) ** 0.5
|
||||||
|
|
||||||
|
@property
|
||||||
|
def step_index(self):
|
||||||
|
"""
|
||||||
|
TODO: Nice docstring
|
||||||
|
"""
|
||||||
|
return self._step_index
|
||||||
|
|
||||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||||
"""
|
"""
|
||||||
@@ -221,20 +235,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
|
|
||||||
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
|
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
|
||||||
if self.config.timestep_spacing == "linspace":
|
if self.config.timestep_spacing == "linspace":
|
||||||
timesteps = (
|
timesteps = np.linspace(0, last_timestep - 1, num_inference_steps)[::-1].copy().astype(np.float32)
|
||||||
np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64)
|
|
||||||
)
|
|
||||||
elif self.config.timestep_spacing == "leading":
|
elif self.config.timestep_spacing == "leading":
|
||||||
step_ratio = last_timestep // (num_inference_steps + 1)
|
step_ratio = last_timestep // self.num_inference_steps
|
||||||
# creates integer timesteps by multiplying by ratio
|
# creates integer timesteps by multiplying by ratio
|
||||||
# casting to int to avoid issues when num_inference_step is power of 3
|
# casting to int to avoid issues when num_inference_step is power of 3
|
||||||
timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
|
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
|
||||||
timesteps += self.config.steps_offset
|
timesteps += self.config.steps_offset
|
||||||
elif self.config.timestep_spacing == "trailing":
|
elif self.config.timestep_spacing == "trailing":
|
||||||
step_ratio = self.config.num_train_timesteps / num_inference_steps
|
step_ratio = self.config.num_train_timesteps / num_inference_steps
|
||||||
# creates integer timesteps by multiplying by ratio
|
# creates integer timesteps by multiplying by ratio
|
||||||
# casting to int to avoid issues when num_inference_step is power of 3
|
# casting to int to avoid issues when num_inference_step is power of 3
|
||||||
timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64)
|
timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.float32)
|
||||||
timesteps -= 1
|
timesteps -= 1
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -242,18 +254,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||||
|
log_sigmas = np.log(sigmas)
|
||||||
|
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||||
|
|
||||||
if self.config.use_karras_sigmas:
|
if self.config.use_karras_sigmas:
|
||||||
log_sigmas = np.log(sigmas)
|
|
||||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||||
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
|
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
|
||||||
timesteps = np.flip(timesteps).copy().astype(np.int64)
|
|
||||||
|
|
||||||
self.sigmas = torch.from_numpy(sigmas)
|
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||||
|
self.sigmas = torch.from_numpy(sigmas).to(device=device)
|
||||||
# when num_inference_steps == num_train_timesteps, we can end up with
|
|
||||||
# duplicates in timesteps.
|
|
||||||
_, unique_indices = np.unique(timesteps, return_index=True)
|
|
||||||
timesteps = timesteps[np.sort(unique_indices)]
|
|
||||||
|
|
||||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||||
|
|
||||||
@@ -264,6 +273,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
] * self.config.solver_order
|
] * self.config.solver_order
|
||||||
self.lower_order_nums = 0
|
self.lower_order_nums = 0
|
||||||
|
|
||||||
|
# add an index counter for schedulers that allow duplicated timesteps
|
||||||
|
self._step_index = None
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
"""
|
"""
|
||||||
@@ -371,13 +383,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
# DPM-Solver and DPM-Solver++ only need the "mean" output.
|
||||||
if self.config.variance_type in ["learned", "learned_range"]:
|
if self.config.variance_type in ["learned", "learned_range"]:
|
||||||
model_output = model_output[:, :3]
|
model_output = model_output[:, :3]
|
||||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
sigma = self.sigmas[self.step_index]
|
||||||
x0_pred = (sample - sigma_t * model_output) / alpha_t
|
x0_pred = sample - sigma * model_output
|
||||||
elif self.config.prediction_type == "sample":
|
elif self.config.prediction_type == "sample":
|
||||||
x0_pred = model_output
|
x0_pred = model_output
|
||||||
elif self.config.prediction_type == "v_prediction":
|
elif self.config.prediction_type == "v_prediction":
|
||||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
sigma = self.sigmas[self.step_index]
|
||||||
x0_pred = alpha_t * sample - sigma_t * model_output
|
x0_pred = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
||||||
@@ -442,19 +454,24 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
`torch.FloatTensor`:
|
`torch.FloatTensor`:
|
||||||
The sample tensor at the previous timestep.
|
The sample tensor at the previous timestep.
|
||||||
"""
|
"""
|
||||||
lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep]
|
|
||||||
alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
def t_fn(_sigma):
|
||||||
sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep]
|
return -torch.log(_sigma)
|
||||||
h = lambda_t - lambda_s
|
|
||||||
|
# YiYi notes: keep these for now so don't get an error, don't need once fully refactored
|
||||||
|
#alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep]
|
||||||
|
|
||||||
|
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||||
|
h = t_fn(sigma_t) - t_fn(sigma_s)
|
||||||
if self.config.algorithm_type == "dpmsolver++":
|
if self.config.algorithm_type == "dpmsolver++":
|
||||||
x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
|
x_t = (sigma_t / sigma_s) * sample - (torch.exp(-h) - 1.0) * model_output
|
||||||
elif self.config.algorithm_type == "dpmsolver":
|
elif self.config.algorithm_type == "dpmsolver":
|
||||||
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
|
||||||
elif self.config.algorithm_type == "sde-dpmsolver++":
|
elif self.config.algorithm_type == "sde-dpmsolver++":
|
||||||
assert noise is not None
|
assert noise is not None
|
||||||
x_t = (
|
x_t = (
|
||||||
(sigma_t / sigma_s * torch.exp(-h)) * sample
|
(sigma_t / sigma_s * torch.exp(-h)) * sample
|
||||||
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
|
+ (1 - torch.exp(-2.0 * h)) * model_output
|
||||||
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
||||||
)
|
)
|
||||||
elif self.config.algorithm_type == "sde-dpmsolver":
|
elif self.config.algorithm_type == "sde-dpmsolver":
|
||||||
@@ -491,27 +508,34 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
`torch.FloatTensor`:
|
`torch.FloatTensor`:
|
||||||
The sample tensor at the previous timestep.
|
The sample tensor at the previous timestep.
|
||||||
"""
|
"""
|
||||||
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
|
|
||||||
|
def t_fn(_sigma):
|
||||||
|
return -torch.log(_sigma)
|
||||||
|
|
||||||
|
# YiYi notes: keep these for now so don't get an error, not needed once fully refactored
|
||||||
|
#t, s0 = prev_timestep, timestep_list[-1]
|
||||||
|
#alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
||||||
|
|
||||||
|
sigma_t, sigma_s0, sigma_s1 = (
|
||||||
|
self.sigmas[self.step_index + 1],
|
||||||
|
self.sigmas[self.step_index],
|
||||||
|
self.sigmas[self.step_index - 1],
|
||||||
|
)
|
||||||
m0, m1 = model_output_list[-1], model_output_list[-2]
|
m0, m1 = model_output_list[-1], model_output_list[-2]
|
||||||
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
|
|
||||||
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
|
h, h_0 = t_fn(sigma_t) - t_fn(sigma_s0), t_fn(sigma_s0) - t_fn(sigma_s1)
|
||||||
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
|
|
||||||
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
|
|
||||||
r0 = h_0 / h
|
r0 = h_0 / h
|
||||||
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
|
||||||
|
|
||||||
if self.config.algorithm_type == "dpmsolver++":
|
if self.config.algorithm_type == "dpmsolver++":
|
||||||
# See https://arxiv.org/abs/2211.01095 for detailed derivations
|
# See https://arxiv.org/abs/2211.01095 for detailed derivations
|
||||||
if self.config.solver_type == "midpoint":
|
if self.config.solver_type == "midpoint":
|
||||||
x_t = (
|
x_t = (sigma_t / sigma_s0) * sample - (torch.exp(-h) - 1.0) * D0 - 0.5 * (torch.exp(-h) - 1.0) * D1
|
||||||
(sigma_t / sigma_s0) * sample
|
|
||||||
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
|
||||||
- 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1
|
|
||||||
)
|
|
||||||
elif self.config.solver_type == "heun":
|
elif self.config.solver_type == "heun":
|
||||||
x_t = (
|
x_t = (
|
||||||
(sigma_t / sigma_s0) * sample
|
(sigma_t / sigma_s0) * sample
|
||||||
- (alpha_t * (torch.exp(-h) - 1.0)) * D0
|
- (torch.exp(-h) - 1.0) * D0
|
||||||
+ (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1
|
+ ((torch.exp(-h) - 1.0) / h + 1.0) * D1
|
||||||
)
|
)
|
||||||
elif self.config.algorithm_type == "dpmsolver":
|
elif self.config.algorithm_type == "dpmsolver":
|
||||||
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
# See https://arxiv.org/abs/2206.00927 for detailed derivations
|
||||||
@@ -532,15 +556,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
if self.config.solver_type == "midpoint":
|
if self.config.solver_type == "midpoint":
|
||||||
x_t = (
|
x_t = (
|
||||||
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
||||||
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
+ (1 - torch.exp(-2.0 * h)) * D0
|
||||||
+ 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1
|
+ 0.5 * (1 - torch.exp(-2.0 * h)) * D1
|
||||||
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
||||||
)
|
)
|
||||||
elif self.config.solver_type == "heun":
|
elif self.config.solver_type == "heun":
|
||||||
x_t = (
|
x_t = (
|
||||||
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
(sigma_t / sigma_s0 * torch.exp(-h)) * sample
|
||||||
+ (alpha_t * (1 - torch.exp(-2.0 * h))) * D0
|
+ (1 - torch.exp(-2.0 * h)) * D0
|
||||||
+ (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1
|
+ ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0) * D1
|
||||||
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
+ sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
|
||||||
)
|
)
|
||||||
elif self.config.algorithm_type == "sde-dpmsolver":
|
elif self.config.algorithm_type == "sde-dpmsolver":
|
||||||
@@ -619,6 +643,23 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
)
|
)
|
||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
|
def _init_step_index(self, timestep):
|
||||||
|
if isinstance(timestep, torch.Tensor):
|
||||||
|
timestep = timestep.to(self.timesteps.device)
|
||||||
|
|
||||||
|
index_candidates = (self.timesteps == timestep).nonzero()
|
||||||
|
|
||||||
|
# The sigma index that is taken for the **very** first `step`
|
||||||
|
# is always the second index (or the last index if there is only 1)
|
||||||
|
# This way we can ensure we don't accidentally skip a sigma in
|
||||||
|
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||||
|
if len(index_candidates) > 1:
|
||||||
|
step_index = index_candidates[1]
|
||||||
|
else:
|
||||||
|
step_index = index_candidates[0]
|
||||||
|
|
||||||
|
self._step_index = step_index.item()
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self,
|
self,
|
||||||
model_output: torch.FloatTensor,
|
model_output: torch.FloatTensor,
|
||||||
@@ -654,19 +695,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(timestep, torch.Tensor):
|
if self.step_index is None:
|
||||||
timestep = timestep.to(self.timesteps.device)
|
self._init_step_index(timestep)
|
||||||
step_index = (self.timesteps == timestep).nonzero()
|
|
||||||
if len(step_index) == 0:
|
prev_timestep = 0 if self.step_index == len(self.timesteps) - 1 else self.timesteps[self.step_index + 1]
|
||||||
step_index = len(self.timesteps) - 1
|
lower_order_final = self.step_index == len(self.timesteps) - 1
|
||||||
else:
|
|
||||||
step_index = step_index.item()
|
|
||||||
prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1]
|
|
||||||
lower_order_final = (
|
|
||||||
(step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15
|
|
||||||
)
|
|
||||||
lower_order_second = (
|
lower_order_second = (
|
||||||
(step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
(self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15
|
||||||
)
|
)
|
||||||
|
|
||||||
model_output = self.convert_model_output(model_output, timestep, sample)
|
model_output = self.convert_model_output(model_output, timestep, sample)
|
||||||
@@ -686,12 +721,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
model_output, timestep, prev_timestep, sample, noise=noise
|
model_output, timestep, prev_timestep, sample, noise=noise
|
||||||
)
|
)
|
||||||
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
|
elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
|
||||||
timestep_list = [self.timesteps[step_index - 1], timestep]
|
timestep_list = [self.timesteps[self.step_index - 1], timestep]
|
||||||
prev_sample = self.multistep_dpm_solver_second_order_update(
|
prev_sample = self.multistep_dpm_solver_second_order_update(
|
||||||
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
|
self.model_outputs, timestep_list, prev_timestep, sample, noise=noise
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep]
|
timestep_list = [self.timesteps[self.step_index - 2], self.timesteps[self.step_index - 1], timestep]
|
||||||
prev_sample = self.multistep_dpm_solver_third_order_update(
|
prev_sample = self.multistep_dpm_solver_third_order_update(
|
||||||
self.model_outputs, timestep_list, prev_timestep, sample
|
self.model_outputs, timestep_list, prev_timestep, sample
|
||||||
)
|
)
|
||||||
@@ -699,24 +734,37 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
if self.lower_order_nums < self.config.solver_order:
|
if self.lower_order_nums < self.config.solver_order:
|
||||||
self.lower_order_nums += 1
|
self.lower_order_nums += 1
|
||||||
|
|
||||||
|
# upon completion increase step index by one
|
||||||
|
self._step_index += 1
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (prev_sample,)
|
return (prev_sample,)
|
||||||
|
|
||||||
return SchedulerOutput(prev_sample=prev_sample)
|
return SchedulerOutput(prev_sample=prev_sample)
|
||||||
|
|
||||||
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
|
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.scale_model_input
|
||||||
|
def scale_model_input(
|
||||||
|
self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
|
||||||
|
) -> torch.FloatTensor:
|
||||||
"""
|
"""
|
||||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
|
||||||
current timestep.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sample (`torch.FloatTensor`):
|
sample (`torch.FloatTensor`): input sample
|
||||||
The input sample.
|
timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
`torch.FloatTensor`:
|
`torch.FloatTensor`:
|
||||||
A scaled input sample.
|
A scaled input sample.
|
||||||
"""
|
"""
|
||||||
|
if self.step_index is None:
|
||||||
|
self._init_step_index(timestep)
|
||||||
|
|
||||||
|
sigma = self.sigmas[self.step_index]
|
||||||
|
|
||||||
|
sample = sample / ((sigma**2 + 1) ** 0.5)
|
||||||
|
|
||||||
|
self.is_scale_input_called = True
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
|
||||||
|
|||||||
@@ -264,10 +264,10 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
|||||||
|
|
||||||
assert sample.dtype == torch.float16
|
assert sample.dtype == torch.float16
|
||||||
|
|
||||||
def test_unique_timesteps(self, **config):
|
def test_duplicated_timesteps(self, **config):
|
||||||
for scheduler_class in self.scheduler_classes:
|
for scheduler_class in self.scheduler_classes:
|
||||||
scheduler_config = self.get_scheduler_config(**config)
|
scheduler_config = self.get_scheduler_config(**config)
|
||||||
scheduler = scheduler_class(**scheduler_config)
|
scheduler = scheduler_class(**scheduler_config)
|
||||||
|
|
||||||
scheduler.set_timesteps(scheduler.config.num_train_timesteps)
|
scheduler.set_timesteps(scheduler.config.num_train_timesteps)
|
||||||
assert len(scheduler.timesteps.unique()) == scheduler.num_inference_steps
|
assert len(scheduler.timesteps) == scheduler.num_inference_steps
|
||||||
|
|||||||
Reference in New Issue
Block a user