Compare commits

...

6 Commits

Author SHA1 Message Date
Aryan
3bf39613fb update 2024-09-24 13:23:29 +02:00
Aryan
8e78c9d1d1 update 2024-09-24 07:34:53 +02:00
Aryan
b7dd6ba4f1 Merge branch 'main' into cogvideox-profiling 2024-09-24 06:33:21 +02:00
Aryan
3ae6094d1a dump 2024-09-24 06:31:53 +02:00
Aryan
fb0ebbb731 update 2024-09-20 01:04:12 +02:00
Aryan
5687dc6d39 profile cogvideox 2024-09-20 00:00:05 +02:00
2 changed files with 32 additions and 10 deletions

View File

@@ -231,7 +231,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
# TODO: discuss with YiYi whether the explicit .copy() was really needed; removed for now.
# NOTE(review): [::-1] yields a negative-stride view, which torch.from_numpy rejects, but the
# subsequent .astype(np.int64) returns a fresh contiguous array anyway (NumPy astype defaults
# to copy=True), so the explicit .copy() appears redundant — confirm before merging.
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64))
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
"""
@@ -251,8 +253,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return sample
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep)
safe_prev_timestep = torch.clamp(prev_timestep, min=0)
safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep)
alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -338,6 +344,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
)
self.timesteps = torch.from_numpy(timesteps).to(device)
self.alphas_cumprod = self.alphas_cumprod.to(device)
self.final_alpha_cumprod = self.final_alpha_cumprod.to(device)
def step(
self,
@@ -402,8 +410,11 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep)
safe_prev_timestep = torch.clamp(prev_timestep, min=0)
safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep)
alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t

View File

@@ -228,11 +228,17 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
# TODO: discuss with YiYi whether the explicit .copy() was really needed; removed for now.
# NOTE(review): [::-1] yields a negative-stride view, which torch.from_numpy rejects, but the
# subsequent .astype(np.int64) returns a fresh contiguous array anyway (NumPy astype defaults
# to copy=True), so the explicit .copy() appears redundant — confirm before merging.
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].astype(np.int64))
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep)
safe_prev_timestep = torch.clamp(prev_timestep, min=0)
safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep)
alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -301,6 +307,8 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
)
self.timesteps = torch.from_numpy(timesteps).to(device)
self.alphas_cumprod = self.alphas_cumprod.to(device)
self.final_alpha_cumprod = self.final_alpha_cumprod.to(device)
def step(
self,
@@ -365,8 +373,11 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep)
safe_prev_timestep = torch.clamp(prev_timestep, min=0)
safe_alpha_prod_t_prev = torch.gather(self.alphas_cumprod, 0, safe_prev_timestep)
alpha_prod_t_prev = torch.where(prev_timestep >= 0, safe_alpha_prod_t_prev, self.final_alpha_cumprod)
beta_prod_t = 1 - alpha_prod_t