Compare commits

..

1 Commits

Author SHA1 Message Date
sayakpaul
3d5c9648e6 make qwen hidden states contiguous to make torchao happy. 2026-02-04 18:13:26 +05:30
10 changed files with 51 additions and 160 deletions

View File

@@ -66,7 +66,7 @@ from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConf
from torchao.quantization import Int4WeightOnlyConfig
pipeline_quant_config = PipelineQuantizationConfig(
quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128))}
quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128)))}
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",

View File

@@ -561,11 +561,11 @@ class QwenDoubleStreamAttnProcessor2_0:
img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
# Apply output projections
img_attn_output = attn.to_out[0](img_attn_output)
img_attn_output = attn.to_out[0](img_attn_output.contiguous())
if len(attn.to_out) > 1:
img_attn_output = attn.to_out[1](img_attn_output) # dropout
txt_attn_output = attn.to_add_out(txt_attn_output)
txt_attn_output = attn.to_add_out(txt_attn_output.contiguous())
return img_attn_output, txt_attn_output

View File

@@ -545,9 +545,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.

View File

@@ -867,9 +867,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.

View File

@@ -245,26 +245,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if (
sum(
[
self.config.use_beta_sigmas,
self.config.use_exponential_sigmas,
self.config.use_karras_sigmas,
]
)
> 1
):
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate(
"algorithm_types dpmsolver and sde-dpmsolver",
"1.0.0",
deprecation_message,
)
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -272,15 +259,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -308,12 +287,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0
# settings for DPM-Solver
if algorithm_type not in [
"dpmsolver",
"dpmsolver++",
"sde-dpmsolver",
"sde-dpmsolver++",
]:
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++")
else:
@@ -750,7 +724,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self,
model_output: torch.Tensor,
*args,
sample: Optional[torch.Tensor] = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
"""
@@ -764,7 +738,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.Tensor`, *optional*):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
@@ -848,7 +822,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self,
model_output: torch.Tensor,
*args,
sample: Optional[torch.Tensor] = None,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -858,10 +832,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.Tensor`, *optional*):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
noise (`torch.Tensor`, *optional*):
The noise tensor.
Returns:
`torch.Tensor`:
@@ -888,10 +860,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -922,7 +891,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self,
model_output_list: List[torch.Tensor],
*args,
sample: Optional[torch.Tensor] = None,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -932,7 +901,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`, *optional*):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
@@ -1045,7 +1014,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self,
model_output_list: List[torch.Tensor],
*args,
sample: Optional[torch.Tensor] = None,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -1055,10 +1024,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`, *optional*):
sample (`torch.Tensor`):
A current instance of a sample created by diffusion process.
noise (`torch.Tensor`, *optional*):
The noise tensor.
Returns:
`torch.Tensor`:
@@ -1139,9 +1106,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return x_t
def index_for_timestep(
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
@@ -1251,10 +1216,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sample = sample.to(torch.float32)
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
noise = randn_tensor(
model_output.shape,
generator=generator,
device=model_output.device,
dtype=torch.float32,
model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
)
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
noise = variance_noise.to(device=model_output.device, dtype=torch.float32)

View File

@@ -141,10 +141,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
use_flow_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
flow_shift (`float`, *optional*, defaults to 1.0):
The flow shift factor. Valid only when `use_flow_sigmas=True`.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
@@ -167,15 +163,15 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++",
solver_type: Literal["midpoint", "heun"] = "midpoint",
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
@@ -184,32 +180,19 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
use_flow_sigmas: Optional[bool] = False,
flow_shift: Optional[float] = 1.0,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[Literal["learned", "learned_range"]] = None,
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if (
sum(
[
self.config.use_beta_sigmas,
self.config.use_exponential_sigmas,
self.config.use_karras_sigmas,
]
)
> 1
):
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate(
"algorithm_types dpmsolver and sde-dpmsolver",
"1.0.0",
deprecation_message,
)
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -217,15 +200,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -244,12 +219,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0
# settings for DPM-Solver
if algorithm_type not in [
"dpmsolver",
"dpmsolver++",
"sde-dpmsolver",
"sde-dpmsolver++",
]:
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++")
else:
@@ -280,11 +250,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
"""
return self._step_index
def set_timesteps(
self,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
):
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -416,7 +382,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
def _sigma_to_t(self, sigma, log_sigmas):
"""
Convert sigma values to corresponding timestep values through interpolation.
@@ -453,7 +419,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.
@@ -475,7 +441,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
Models](https://huggingface.co/papers/2206.00364).
@@ -601,7 +567,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self,
model_output: torch.Tensor,
*args,
sample: Optional[torch.Tensor] = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
"""
@@ -615,7 +581,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.Tensor`, *optional*):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
@@ -700,7 +666,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self,
model_output: torch.Tensor,
*args,
sample: Optional[torch.Tensor] = None,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -710,10 +676,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.Tensor`, *optional*):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
noise (`torch.Tensor`, *optional*):
The noise tensor.
Returns:
`torch.Tensor`:
@@ -740,10 +704,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -775,7 +736,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self,
model_output_list: List[torch.Tensor],
*args,
sample: Optional[torch.Tensor] = None,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -785,7 +746,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`, *optional*):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
@@ -899,7 +860,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self,
model_output_list: List[torch.Tensor],
*args,
sample: Optional[torch.Tensor] = None,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -909,10 +870,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`, *optional*):
sample (`torch.Tensor`):
A current instance of a sample created by diffusion process.
noise (`torch.Tensor`, *optional*):
The noise tensor.
Returns:
`torch.Tensor`:
@@ -992,7 +951,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
)
return x_t
def _init_step_index(self, timestep: Union[int, torch.Tensor]):
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -1016,7 +975,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator: Optional[torch.Generator] = None,
generator=None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
@@ -1068,10 +1027,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
noise = randn_tensor(
model_output.shape,
generator=generator,
device=model_output.device,
dtype=model_output.dtype,
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
)
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
noise = variance_noise
@@ -1118,21 +1074,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
"""
Add noise to the clean `original_samples` using the scheduler's equivalent function.
Args:
original_samples (`torch.Tensor`):
The original samples to add noise to.
noise (`torch.Tensor`):
The noise tensor.
timesteps (`torch.IntTensor`):
The timesteps at which to add noise.
Returns:
`torch.Tensor`:
The noisy samples.
"""
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -1162,5 +1103,5 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
def __len__(self) -> int:
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -1120,9 +1120,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.

View File

@@ -662,9 +662,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.

View File

@@ -1122,9 +1122,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.

View File

@@ -1083,9 +1083,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.