mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-12 23:44:30 +08:00
Compare commits
3 Commits
v0.26.0-re
...
fix-schedu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df493b90ac | ||
|
|
c2a8afc60c | ||
|
|
119cf05bfa |
@@ -551,6 +551,7 @@ class StableDiffusionImg2ImgPipeline(
|
|||||||
|
|
||||||
t_start = max(num_inference_steps - init_timestep, 0)
|
t_start = max(num_inference_steps - init_timestep, 0)
|
||||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||||
|
self.scheduler._step_index_init = t_start * self.scheduler.order
|
||||||
|
|
||||||
return timesteps, num_inference_steps - t_start
|
return timesteps, num_inference_steps - t_start
|
||||||
|
|
||||||
|
|||||||
@@ -215,6 +215,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
self.model_outputs = [None] * solver_order
|
self.model_outputs = [None] * solver_order
|
||||||
self.lower_order_nums = 0
|
self.lower_order_nums = 0
|
||||||
self._step_index = None
|
self._step_index = None
|
||||||
|
self._step_index_init = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def step_index(self):
|
def step_index(self):
|
||||||
@@ -223,6 +224,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
"""
|
"""
|
||||||
return self._step_index
|
return self._step_index
|
||||||
|
|
||||||
|
@property
|
||||||
|
def step_index_init(self):
|
||||||
|
"""
|
||||||
|
the first step_index for denoising loop.
|
||||||
|
"""
|
||||||
|
return self._step_index_init
|
||||||
|
|
||||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||||
"""
|
"""
|
||||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||||
@@ -760,6 +768,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
return x_t
|
return x_t
|
||||||
|
|
||||||
def _init_step_index(self, timestep):
|
def _init_step_index(self, timestep):
|
||||||
|
|
||||||
|
if self.step_index_init is None:
|
||||||
if isinstance(timestep, torch.Tensor):
|
if isinstance(timestep, torch.Tensor):
|
||||||
timestep = timestep.to(self.timesteps.device)
|
timestep = timestep.to(self.timesteps.device)
|
||||||
|
|
||||||
@@ -776,7 +786,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
else:
|
else:
|
||||||
step_index = index_candidates[0].item()
|
step_index = index_candidates[0].item()
|
||||||
|
|
||||||
|
self._step_index_init = step_index
|
||||||
self._step_index = step_index
|
self._step_index = step_index
|
||||||
|
else:
|
||||||
|
self._step_index = self.step_index_init
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self,
|
self,
|
||||||
@@ -884,8 +897,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||||||
else:
|
else:
|
||||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||||
timesteps = timesteps.to(original_samples.device)
|
timesteps = timesteps.to(original_samples.device)
|
||||||
|
if self.step_index_init is None:
|
||||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||||
|
else:
|
||||||
|
step_indices = [self.step_index_init] * timesteps.shape[0]
|
||||||
|
|
||||||
sigma = sigmas[step_indices].flatten()
|
sigma = sigmas[step_indices].flatten()
|
||||||
while len(sigma.shape) < len(original_samples.shape):
|
while len(sigma.shape) < len(original_samples.shape):
|
||||||
|
|||||||
Reference in New Issue
Block a user