Mirror of https://github.com/huggingface/diffusers.git
Synced 2025-12-12 15:34:17 +08:00

Compare commits: kernelize...add-stocha (7 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | ff1012f8cf |  |
|  | 25bc77d8f8 |  |
|  | 9c35a89921 |  |
|  | 32d9aef997 |  |
|  | 9edc5beddc |  |
|  | f87956e9cf |  |
|  | 690adb5bd9 |  |
@@ -80,6 +80,8 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             Whether to use beta sigmas for step sizes in the noise schedule during sampling.
         time_shift_type (`str`, defaults to "exponential"):
             The type of dynamic resolution-dependent timestep shifting to apply. Either "exponential" or "linear".
+        stochastic_sampling (`bool`, defaults to False):
+            Whether to use stochastic sampling.
     """

     _compatibles = []
@@ -101,6 +103,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
         time_shift_type: str = "exponential",
+        stochastic_sampling: bool = False,
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
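A minimal usage sketch of opting into the new flag (not part of this diff: the `from_config` override is the standard `ConfigMixin` pattern in diffusers, and `pipe` is an assumed, already-loaded flow-match pipeline):

```python
from diffusers import FlowMatchEulerDiscreteScheduler

# Construct directly; every other config option keeps its default.
scheduler = FlowMatchEulerDiscreteScheduler(stochastic_sampling=True)

# Or rebuild from an existing pipeline's scheduler config, overriding only
# this flag (pipe is an assumed, already-loaded flow-match pipeline).
# scheduler = FlowMatchEulerDiscreteScheduler.from_config(
#     pipe.scheduler.config, stochastic_sampling=True
# )
```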
@@ -437,13 +440,25 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             lower_mask = sigmas < per_token_sigmas[None] - 1e-6
             lower_sigmas = lower_mask * sigmas
             lower_sigmas, _ = lower_sigmas.max(dim=0)
-            dt = (per_token_sigmas - lower_sigmas)[..., None]
+
+            current_sigma = per_token_sigmas[..., None]
+            next_sigma = lower_sigmas[..., None]
+            dt = current_sigma - next_sigma
         else:
-            sigma = self.sigmas[self.step_index]
-            sigma_next = self.sigmas[self.step_index + 1]
+            sigma_idx = self.step_index
+            sigma = self.sigmas[sigma_idx]
+            sigma_next = self.sigmas[sigma_idx + 1]
+
+            current_sigma = sigma
+            next_sigma = sigma_next
             dt = sigma_next - sigma

-        prev_sample = sample + dt * model_output
+        if self.config.stochastic_sampling:
+            x0 = sample - current_sigma * model_output
+            noise = torch.randn_like(sample)
+            prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
+        else:
+            prev_sample = sample + dt * model_output

         # upon completion increase step index by one
         self._step_index += 1
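For intuition on the two branches: under the flow-matching convention `sample = (1 - sigma) * x0 + sigma * noise`, the model output approximates the velocity `noise - x0`, so `sample - current_sigma * model_output` recovers a clean estimate `x0`, and the stochastic branch re-noises that estimate down to `next_sigma` with fresh Gaussian noise; the deterministic Euler step is algebraically the same re-noising but keeps the original noise direction. A self-contained sketch checking this (toy shapes, illustrative sigma values, and an oracle prediction; none of this is scheduler code):

```python
import torch

torch.manual_seed(0)

# Toy setup under the flow-matching convention used by this scheduler:
#   sample = (1 - sigma) * x0 + sigma * noise,  model_output ~ noise - x0.
x0_true = torch.randn(2, 4)
noise = torch.randn(2, 4)
sigma, sigma_next = 0.8, 0.6          # illustrative values from a decreasing schedule
sample = (1.0 - sigma) * x0_true + sigma * noise
model_output = noise - x0_true        # oracle velocity prediction, for the demo only

# Deterministic branch: plain Euler step along the probability-flow ODE.
dt = sigma_next - sigma
prev_det = sample + dt * model_output

# Stochastic branch: recover x0 from the velocity, then re-noise it to sigma_next.
x0_hat = sample - sigma * model_output
fresh = torch.randn_like(sample)
prev_sto = (1.0 - sigma_next) * x0_hat + sigma_next * fresh

# With an oracle prediction, x0 is recovered exactly, and the Euler step equals
# re-noising with the *same* noise instead of a fresh draw.
assert torch.allclose(x0_hat, x0_true, atol=1e-6)
assert torch.allclose(prev_det, (1.0 - sigma_next) * x0_true + sigma_next * noise, atol=1e-6)
```

The practical consequence is that `stochastic_sampling=True` injects fresh noise at every step rather than committing to the initial noise draw for the whole trajectory.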