Compare commits

...

1 Commits

Author SHA1 Message Date
yiyixuxu
03186e1f59 add 2024-12-19 02:20:15 +01:00

View File

@@ -245,6 +245,13 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
num_inference_steps = len(sigmas)
self.num_inference_steps = num_inference_steps
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
@@ -253,14 +260,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
if self.config.shift_terminal:
sigmas = self.stretch_shift_to_terminal(sigmas)
if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
elif self.config.use_exponential_sigmas:
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
elif self.config.use_beta_sigmas:
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps