Mirror of https://github.com/huggingface/diffusers.git
Synced 2025-12-16 01:14:47 +08:00

Compare commits: fp8-note-t ... cogvideox- (6 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 3bf39613fb | |
| | 8e78c9d1d1 | |
| | b7dd6ba4f1 | |
| | 3ae6094d1a | |
| | fb0ebbb731 | |
| | 5687dc6d39 | |
```diff
@@ -231,7 +231,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+        # TODO: discuss with YiYi why we have a .copy() here and if it's really needed. I've removed it for now
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64))
 
     def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
```
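On the `.copy()` question in the TODO above (my reading, not something the commit states): `np.arange(...)[::-1]` is a negative-stride view, which `torch.from_numpy` rejects outright, but the trailing `.astype(np.int64)` defaults to `copy=True` and returns a fresh contiguous array, so the explicit `.copy()` was redundant rather than load-bearing. A minimal check:

```python
import numpy as np
import torch

num_train_timesteps = 1000
reversed_view = np.arange(0, num_train_timesteps)[::-1]

# torch.from_numpy(reversed_view) would raise ValueError here:
# negative-stride arrays are not supported.

# .astype copies by default (copy=True), so it already yields a fresh
# contiguous array; the extra .copy() added nothing.
timesteps = torch.from_numpy(reversed_view.astype(np.int64))
print(timesteps[:3])  # tensor([999, 998, 997])
```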
```diff
@@ -251,8 +253,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         return sample
 
     def _get_variance(self, timestep, prev_timestep):
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep)
+
+        safe_prev_timestep = torch.clamp(prev_timestep, min=0)
+        safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep)
+        alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod)
+
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
```
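The `_get_variance` change swaps a data-dependent Python branch for pure tensor ops: `torch.clamp` keeps a speculative lookup in-bounds, `torch.gather` performs the lookup, and `torch.where` selects the fallback. That keeps the method traceable (no graph break on `if prev_timestep >= 0`) and lets `timestep` be a tensor instead of a Python int. A standalone sketch of the pattern, with dummy values of my own:

```python
import torch

alphas_cumprod = torch.linspace(0.9999, 0.006, 1000)  # dummy cumulative-alpha schedule
final_alpha_cumprod = torch.tensor(1.0)

timestep = torch.tensor([0])   # final step of the schedule
prev_timestep = timestep - 20  # -20: no previous timestep exists

# Branch-free equivalent of:
#   prev = alphas_cumprod[prev_timestep] if prev_timestep >= 0 else final_alpha_cumprod
alpha_prod_t = torch.gather(alphas_cumprod, 0, timestep)
safe_prev_timestep = torch.clamp(prev_timestep, min=0)  # keep the gather in-bounds
safe_alpha_prod_t_prev = torch.gather(alphas_cumprod, 0, safe_prev_timestep)
alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, final_alpha_cumprod)

print(alpha_prod_t_prev)  # tensor([1.]) -- the fallback was selected
```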
```diff
@@ -338,6 +344,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         )
 
         self.timesteps = torch.from_numpy(timesteps).to(device)
+        self.alphas_cumprod = self.alphas_cumprod.to(device)
+        self.final_alpha_cumprod = self.final_alpha_cumprod.to(device)
 
     def step(
         self,
```
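Presumably the two new `.to(device)` lines exist because the gather-based lookups require the table and the index tensor to live on the same device: once `set_timesteps` places `timesteps` on the accelerator, `alphas_cumprod` and `final_alpha_cumprod` have to follow (the old integer indexing tolerated a CPU-resident table). A small illustration of the failure this avoids, under that assumption:

```python
import torch

alphas_cumprod = torch.linspace(0.9999, 0.006, 1000)  # lives on CPU by default

if torch.cuda.is_available():
    timestep = torch.tensor([980], device="cuda")
    # torch.gather(alphas_cumprod, 0, timestep) would raise a RuntimeError:
    # input and index are on different devices.
    alphas_cumprod = alphas_cumprod.to("cuda")  # mirrors what set_timesteps now does
    alpha_prod_t = torch.gather(alphas_cumprod, 0, timestep)
```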
```diff
@@ -402,8 +410,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
 
         # 2. compute alphas, betas
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep)
+
+        safe_prev_timestep = torch.clamp(prev_timestep, min=0)
+        safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep)
+        alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod)
 
         beta_prod_t = 1 - alpha_prod_t
```
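One behavioral consequence of this `step` hunk (my inference, not stated in the diff): `torch.gather` requires a tensor index whose dimensionality matches the input, so after this change `step` can no longer be fed a plain Python `int` timestep; the caller has to pass a 1-D `int64` tensor on the same device as `alphas_cumprod`. Roughly:

```python
import torch

alphas_cumprod = torch.linspace(0.9999, 0.006, 1000)

# torch.gather(alphas_cumprod, 0, 980)                 # TypeError: index must be a tensor
# torch.gather(alphas_cumprod, 0, torch.tensor(980))   # fails too: 0-dim index vs 1-D input
alpha_prod_t = torch.gather(alphas_cumprod, 0, torch.tensor([980]))  # OK: 1-D index
```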
```diff
@@ -228,11 +228,17 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+        # TODO: discuss with YiYi why we have a .copy() here and if it's really needed. I've removed it for now
+        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64))
 
     def _get_variance(self, timestep, prev_timestep):
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep)
+
+        safe_prev_timestep = torch.clamp(prev_timestep, min=0)
+        safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep)
+        alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod)
+
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
```
```diff
@@ -301,6 +307,8 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
         )
 
         self.timesteps = torch.from_numpy(timesteps).to(device)
+        self.alphas_cumprod = self.alphas_cumprod.to(device)
+        self.final_alpha_cumprod = self.final_alpha_cumprod.to(device)
 
     def step(
         self,
```
```diff
@@ -365,8 +373,11 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
         prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
 
         # 2. compute alphas, betas
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep)
+
+        safe_prev_timestep = torch.clamp(prev_timestep, min=0)
+        safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep)
+        alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod)
 
         beta_prod_t = 1 - alpha_prod_t
```
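The CogVideoXDDIMScheduler hunks mirror the DDIMScheduler ones line for line, so after these commits both schedulers share the same tensor-only timestep path. A hypothetical end-to-end sketch against the patched `DDIMScheduler` (the random `model_output` stands in for a real denoiser; shapes, step counts, and the `view(1)` convention are my illustration, not from the diff):

```python
import torch
from diffusers import DDIMScheduler

device = "cuda" if torch.cuda.is_available() else "cpu"

scheduler = DDIMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=50, device=device)  # now also moves alphas_cumprod

sample = torch.randn(1, 4, 64, 64, device=device)
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)  # stand-in for a UNet forward pass
    # the gather-based path wants a 1-D tensor index, hence view(1)
    sample = scheduler.step(model_output, t.view(1), sample).prev_sample
```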