Compare commits

...

2 Commits

Author SHA1 Message Date
yiyixuxu
9c112aaaca copies 2023-12-07 07:45:12 +00:00
yiyixuxu
0f348e5405 fix 2023-12-07 07:44:32 +00:00
5 changed files with 50 additions and 5 deletions

View File

@@ -734,7 +734,16 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):

View File

@@ -896,7 +896,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):

View File

@@ -891,7 +891,16 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):

View File

@@ -897,7 +897,16 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):

View File

@@ -828,7 +828,16 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
step_indices = []
for timestep in timesteps:
index_candidates = (schedule_timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(schedule_timesteps) - 1
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
step_indices.append(step_index)
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):