mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
3 Commits
controlnet
...
fix-schedu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df493b90ac | ||
|
|
c2a8afc60c | ||
|
|
119cf05bfa |
@@ -551,6 +551,7 @@ class StableDiffusionImg2ImgPipeline(
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
self.scheduler._step_index_init = t_start * self.scheduler.order
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
@@ -215,6 +215,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.model_outputs = [None] * solver_order
|
||||
self.lower_order_nums = 0
|
||||
self._step_index = None
|
||||
self._step_index_init = None
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
@@ -222,6 +223,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
The index counter for current timestep. It will increae 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def step_index_init(self):
|
||||
"""
|
||||
the first step_index for denoising loop.
|
||||
"""
|
||||
return self._step_index_init
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
@@ -760,23 +768,28 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return x_t
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
if self.step_index_init is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
index_candidates = (self.timesteps == timestep).nonzero()
|
||||
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
if len(index_candidates) == 0:
|
||||
step_index = len(self.timesteps) - 1
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
elif len(index_candidates) > 1:
|
||||
step_index = index_candidates[1].item()
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index_init = step_index
|
||||
self._step_index = step_index
|
||||
else:
|
||||
step_index = index_candidates[0].item()
|
||||
|
||||
self._step_index = step_index
|
||||
self._step_index = self.step_index_init
|
||||
|
||||
def step(
|
||||
self,
|
||||
@@ -884,8 +897,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
else:
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
if self.step_index_init is None:
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
else:
|
||||
step_indices = [self.step_index_init] * timesteps.shape[0]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
||||
Reference in New Issue
Block a user