Compare commits

...

7 Commits

Author SHA1 Message Date
apolinário
ff1012f8cf Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-04-21 23:34:30 +02:00
apolinário
25bc77d8f8 Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2025-04-21 23:34:24 +02:00
apolinário
9c35a89921 Swap order 2025-04-21 23:31:04 +02:00
github-actions[bot]
32d9aef997 Apply style fixes 2025-04-21 15:42:31 +00:00
apolinário
9edc5beddc Use config value directly 2025-04-21 17:40:13 +02:00
github-actions[bot]
f87956e9cf Apply style fixes 2025-04-19 19:33:27 +00:00
apolinário
690adb5bd9 Add stochastic sampling to FlowMatchEulerDiscreteScheduler
This PR adds stochastic sampling to FlowMatchEulerDiscreteScheduler based on b1aeddd7cc  ltx_video/schedulers/rf.py
2025-04-19 19:48:16 +02:00

View File

@@ -80,6 +80,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
Whether to use beta sigmas for step sizes in the noise schedule during sampling.
time_shift_type (`str`, defaults to "exponential"):
The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
stochastic_sampling (`bool`, defaults to False):
Whether to use stochastic sampling.
"""
_compatibles = []
@@ -101,6 +103,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
time_shift_type: str = "exponential",
stochastic_sampling: bool = False,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -437,13 +440,25 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
lower_mask = sigmas < per_token_sigmas[None] - 1e-6
lower_sigmas = lower_mask * sigmas
lower_sigmas, _ = lower_sigmas.max(dim=0)
dt = (per_token_sigmas - lower_sigmas)[..., None]
current_sigma = per_token_sigmas[..., None]
next_sigma = lower_sigmas[..., None]
dt = current_sigma - next_sigma
else:
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
sigma_idx = self.step_index
sigma = self.sigmas[sigma_idx]
sigma_next = self.sigmas[sigma_idx + 1]
current_sigma = sigma
next_sigma = sigma_next
dt = sigma_next - sigma
prev_sample = sample + dt * model_output
if self.config.stochastic_sampling:
x0 = sample - current_sigma * model_output
noise = torch.randn_like(sample)
prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
else:
prev_sample = sample + dt * model_output
# upon completion increase step index by one
self._step_index += 1