Compare commits

...

2 Commits

Author SHA1 Message Date
Patrick von Platen
866e1dc777 more fixes 2022-12-09 10:38:38 +00:00
Patrick von Platen
03c5ac0603 correct dpm timesteps 2022-12-09 10:36:51 +00:00
3 changed files with 3 additions and 3 deletions

View File

@@ -197,7 +197,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
self.num_inference_steps = num_inference_steps
timesteps = (
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps)
.round()[::-1][:-1]
.copy()
.astype(np.int64)

View File

@@ -228,7 +228,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
the shape of the samples to be generated.
"""
timesteps = (
jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
jnp.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
.round()[::-1][:-1]
.astype(jnp.int32)
)

View File

@@ -221,7 +221,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
"""
self.num_inference_steps = num_inference_steps
timesteps = (
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
np.linspace(0, self.num_train_timesteps - 1, num_inference_steps)
.round()[::-1][:-1]
.copy()
.astype(np.int64)