Compare commits

...

3 Commits

Author SHA1 Message Date
yiyixuxu
df493b90ac draft 2023-11-10 07:18:32 +00:00
yiyixuxu
c2a8afc60c Revert "fix"
This reverts commit 119cf05bfa.
2023-11-10 04:05:44 +00:00
yiyixuxu
119cf05bfa fix 2023-11-10 03:06:48 +00:00
2 changed files with 32 additions and 16 deletions

View File

@@ -551,6 +551,7 @@ class StableDiffusionImg2ImgPipeline(
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
self.scheduler._step_index_init = t_start * self.scheduler.order
return timesteps, num_inference_steps - t_start

View File

@@ -215,6 +215,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.model_outputs = [None] * solver_order
self.lower_order_nums = 0
self._step_index = None
self._step_index_init = None
@property
def step_index(self):
@@ -222,6 +223,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
The index counter for current timestep. It will increae 1 after each scheduler step.
"""
return self._step_index
@property
def step_index_init(self):
"""
the first step_index for denoising loop.
"""
return self._step_index_init
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
"""
@@ -760,23 +768,28 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return x_t
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
if self.step_index_init is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
index_candidates = (self.timesteps == timestep).nonzero()
index_candidates = (self.timesteps == timestep).nonzero()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
if len(index_candidates) == 0:
step_index = len(self.timesteps) - 1
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
elif len(index_candidates) > 1:
step_index = index_candidates[1].item()
else:
step_index = index_candidates[0].item()
self._step_index_init = step_index
self._step_index = step_index
else:
step_index = index_candidates[0].item()
self._step_index = step_index
self._step_index = self.step_index_init
def step(
self,
@@ -884,8 +897,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
if self.step_index_init is None:
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
else:
step_indices = [self.step_index_init] * timesteps.shape[0]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):