Compare commits

..

9 Commits

Author SHA1 Message Date
David El Malih
e8e88ff2ce Improve docstrings and type hints in scheduling_ddpm_flax.py (#13024)
docs: improve docstring scheduling_ddpm_flax.py
2026-01-23 11:51:47 -08:00
David El Malih
6e24cd842c Improve docstrings and type hints in scheduling_ddim_parallel.py (#13023)
* docs: improve docstring scheduling_ddim_parallel.py

* docs: improve docstring scheduling_ddim_parallel.py

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* fix style

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-01-23 10:00:32 -08:00
Garry Ling
981eb802c6 feat: add qkv projection fuse for longcat transformers (#13021)
feat: add qkv fuse for longcat transformers

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-01-23 23:02:03 +05:30
jiqing-feng
1eb40c6dbd Resnet only use contiguous in training mode. (#12977)
* fix contiguous

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update tol

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* bigger tol

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update tol

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-01-23 18:40:10 +05:30
Sayak Paul
bff672f47f fix Dockerfiles for cuda and xformers. (#13022) 2026-01-23 16:45:14 +05:30
David El Malih
d4f97d1921 Improve docstrings and type hints in scheduling_ddim_inverse.py (#13020)
docs: improve docstring scheduling_ddim_inverse.py
2026-01-22 15:42:45 -08:00
David El Malih
1d32b19ad4 Improve docstrings and type hints in scheduling_ddim_flax.py (#13010)
* docs: improve docstring scheduling_ddim_flax.py

* docs: improve docstring scheduling_ddim_flax.py

* docs: improve docstring scheduling_ddim_flax.py
2026-01-22 09:11:14 -08:00
Garry Ling
699297f647 feat: accelerate longcat-image with regional compile (#13019) 2026-01-22 20:21:45 +05:30
Aryan V S
7a02fadad3 [scheduler] Support custom sigmas in UniPCMultistepScheduler (#12109)
* update

* fix tests

* Apply suggestions from code review

* Revert default flow sigmas change so that tests relying on UniPC multistep still pass

* Remove custom timesteps for UniPC multistep set_timesteps

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: Daniel Gu <dgu8957@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-01-21 17:18:59 -08:00
12 changed files with 187 additions and 72 deletions

View File

@@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ARG PYTHON_VERSION=3.12
ARG PYTHON_VERSION=3.11
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get -y update \
@@ -32,10 +32,12 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
# Install torch, torchvision, and torchaudio together to ensure compatibility
RUN uv pip install --no-cache-dir \
torch \
torchvision \
torchaudio
torchaudio \
--index-url https://download.pytorch.org/whl/cu121
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"

View File

@@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ARG PYTHON_VERSION=3.12
ARG PYTHON_VERSION=3.11
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get -y update \
@@ -32,10 +32,12 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
# Install torch, torchvision, and torchaudio together to ensure compatibility
RUN uv pip install --no-cache-dir \
torch \
torchvision \
torchaudio
torchaudio \
--index-url https://download.pytorch.org/whl/cu121
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"

View File

@@ -366,7 +366,12 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor.contiguous())
# Only use contiguous() during training to avoid DDP gradient stride mismatch warning.
# In inference mode (eval or no_grad), skip contiguous() for better performance, especially on CPU.
# Issue: https://github.com/huggingface/diffusers/issues/12975
if self.training:
input_tensor = input_tensor.contiguous()
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

View File

@@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
@@ -400,12 +400,14 @@ class LongCatImageTransformer2DModel(
PeftAdapterMixin,
FromOriginalModelMixin,
CacheMixin,
AttentionMixin,
):
"""
The Transformer model introduced in Longcat-Image.
"""
_supports_gradient_checkpointing = True
_repeated_blocks = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"]
@register_to_config
def __init__(

View File

@@ -1552,11 +1552,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
else:
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
self._blocks = blocks
self.blocks = blocks
self._components_manager = components_manager
self._collection = collection
self._component_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_components}
self._config_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_configs}
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
# update component_specs and config_specs based on modular_model_index.json
if modular_config_dict is not None:
@@ -1603,9 +1603,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
for name, config_spec in self._config_specs.items():
default_configs[name] = config_spec.default
self.register_to_config(**default_configs)
self.register_to_config(
_blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None
)
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
@property
def default_call_parameters(self) -> Dict[str, Any]:
@@ -1614,7 +1612,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- Dictionary mapping input names to their default values
"""
params = {}
for input_param in self._blocks.inputs:
for input_param in self.blocks.inputs:
params[input_param.name] = input_param.default
return params
@@ -1777,15 +1775,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
Returns:
- The docstring of the pipeline blocks
"""
return self._blocks.doc
@property
def blocks(self) -> ModularPipelineBlocks:
"""
Returns:
- A copy of the pipeline blocks
"""
return deepcopy(self._blocks)
return self.blocks.doc
def register_components(self, **kwargs):
"""
@@ -2519,7 +2509,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
)
def set_progress_bar_config(self, **kwargs):
for sub_block_name, sub_block in self._blocks.sub_blocks.items():
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
if hasattr(sub_block, "set_progress_bar_config"):
sub_block.set_progress_bar_config(**kwargs)
@@ -2573,7 +2563,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
# Add inputs to state, using defaults if not provided in the kwargs or the state
# if same input already in the state, will override it if provided in the kwargs
for expected_input_param in self._blocks.inputs:
for expected_input_param in self.blocks.inputs:
name = expected_input_param.name
default = expected_input_param.default
kwargs_type = expected_input_param.kwargs_type
@@ -2592,9 +2582,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
# Run the pipeline
with torch.no_grad():
try:
_, state = self._blocks(self, state)
_, state = self.blocks(self, state)
except Exception:
error_msg = f"Error in block: ({self._blocks.__class__.__name__}):\n"
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
logger.error(error_msg)
raise

View File

@@ -22,6 +22,7 @@ import flax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
@@ -32,6 +33,9 @@ from .scheduling_utils_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class DDIMSchedulerState:
common: CommonSchedulerState
@@ -125,6 +129,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
@@ -152,7 +160,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
)
def scale_model_input(
self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
self,
state: DDIMSchedulerState,
sample: jnp.ndarray,
timestep: Optional[int] = None,
) -> jnp.ndarray:
"""
Args:
@@ -190,7 +201,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
alpha_prod_t = state.common.alphas_cumprod[timestep]
alpha_prod_t_prev = jnp.where(
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
prev_timestep >= 0,
state.common.alphas_cumprod[prev_timestep],
state.final_alpha_cumprod,
)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev

View File

@@ -99,7 +99,7 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
@@ -187,14 +187,14 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
clip_sample_range: float = 1.0,
timestep_spacing: str = "leading",
timestep_spacing: Literal["leading", "trailing"] = "leading",
rescale_betas_zero_snr: bool = False,
**kwargs,
):
@@ -210,7 +210,15 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -256,7 +264,11 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int,
device: Optional[Union[str, torch.device]] = None,
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -308,20 +320,10 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
eta (`float`):
The weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
`use_clipped_model_output` has no effect.
variance_noise (`torch.Tensor`):
Alternative to generating noise with `generator` by directly providing the noise for the variance
itself. Useful for methods such as [`CycleDiffusion`].
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_ddim_inverse.DDIMInverseSchedulerOutput`] or
`tuple`.
@@ -335,7 +337,8 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
# 1. get previous step value (=t+1)
prev_timestep = timestep
timestep = min(
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
timestep - self.config.num_train_timesteps // self.num_inference_steps,
self.config.num_train_timesteps - 1,
)
# 2. compute alphas, betas
@@ -378,5 +381,5 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
return (prev_sample, pred_original_sample)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -101,7 +101,7 @@ def betas_for_alpha_bar(
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas):
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
"""
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
@@ -266,7 +266,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def _get_variance(self, timestep, prev_timestep=None):
def _get_variance(self, timestep: int, prev_timestep: Optional[int] = None) -> torch.Tensor:
if prev_timestep is None:
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
@@ -279,7 +279,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
return variance
def _batch_get_variance(self, t, prev_t):
def _batch_get_variance(self, t: torch.Tensor, prev_t: torch.Tensor) -> torch.Tensor:
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)
@@ -335,7 +335,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
return sample
# Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -392,7 +392,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
sample: torch.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
generator: Optional[torch.Generator] = None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[DDIMParallelSchedulerOutput, Tuple]:
@@ -406,11 +406,13 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
sample (`torch.Tensor`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
`self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
coincide with the one provided as input and `use_clipped_model_output` will have not effect.
generator: random number generator.
use_clipped_model_output (`bool`, defaults to `False`):
If `True`, compute "corrected" `model_output` from the clipped predicted original sample. This
correction is necessary because the predicted original sample is clipped to [-1, 1] when
`self.config.clip_sample` is `True`. If no clipping occurred, the "corrected" `model_output` matches
the input and `use_clipped_model_output` has no effect.
generator (`torch.Generator`, *optional*):
Random number generator.
variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we
can directly provide the noise for the variance itself. This is useful for methods such as
CycleDiffusion. (https://huggingface.co/papers/2210.05559)
@@ -496,7 +498,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
if variance_noise is None:
variance_noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
model_output.shape,
generator=generator,
device=model_output.device,
dtype=model_output.dtype,
)
variance = std_dev_t * variance_noise
@@ -513,7 +518,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
def batch_step_no_noise(
self,
model_output: torch.Tensor,
timesteps: List[int],
timesteps: torch.Tensor,
sample: torch.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
@@ -528,7 +533,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`): direct output from learned diffusion model.
timesteps (`List[int]`):
timesteps (`torch.Tensor`):
current discrete timesteps in the diffusion chain. This is now a list of integers.
sample (`torch.Tensor`):
current instance of sample being created by diffusion process.
@@ -696,5 +701,5 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -22,6 +22,7 @@ import jax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import logging
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
@@ -32,6 +33,9 @@ from .scheduling_utils_flax import (
)
logger = logging.get_logger(__name__)
@flax.struct.dataclass
class DDPMSchedulerState:
common: CommonSchedulerState
@@ -42,7 +46,12 @@ class DDPMSchedulerState:
num_inference_steps: Optional[int] = None
@classmethod
def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray):
def create(
cls,
common: CommonSchedulerState,
init_noise_sigma: jnp.ndarray,
timesteps: jnp.ndarray,
):
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps)
@@ -105,6 +114,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
prediction_type: str = "epsilon",
dtype: jnp.dtype = jnp.float32,
):
logger.warning(
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
)
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
@@ -123,7 +136,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
)
def scale_model_input(
self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
self,
state: DDPMSchedulerState,
sample: jnp.ndarray,
timestep: Optional[int] = None,
) -> jnp.ndarray:
"""
Args:

View File

@@ -226,6 +226,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
time_shift_type: Literal["exponential"] = "exponential",
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
shift_terminal: Optional[float] = None,
) -> None:
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -245,6 +246,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self.betas = betas_for_alpha_bar(num_train_timesteps)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
if shift_terminal is not None and not use_flow_sigmas:
raise ValueError("`shift_terminal` is only supported when `use_flow_sigmas=True`.")
if rescale_betas_zero_snr:
self.betas = rescale_zero_terminal_snr(self.betas)
@@ -313,8 +316,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = begin_index
def set_timesteps(
self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None
) -> None:
self,
num_inference_steps: Optional[int] = None,
device: Union[str, torch.device] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[float] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -323,13 +330,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
sigmas (`List[float]`, *optional*):
Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
automatically.
mu (`float`, *optional*):
Optional mu parameter for dynamic shifting when using exponential time shift type.
"""
if self.config.use_dynamic_shifting and mu is None:
raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
if sigmas is not None:
if not self.config.use_flow_sigmas:
raise ValueError(
"Passing `sigmas` is only supported when `use_flow_sigmas=True`. "
"Please set `use_flow_sigmas=True` during scheduler initialization."
)
num_inference_steps = len(sigmas)
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if mu is not None:
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
self.config.flow_shift = np.exp(mu)
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
@@ -354,8 +372,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.use_karras_sigmas:
if sigmas is None:
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -375,6 +394,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
elif self.config.use_exponential_sigmas:
if sigmas is None:
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -389,6 +410,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
elif self.config.use_beta_sigmas:
if sigmas is None:
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
sigmas = np.flip(sigmas).copy()
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -403,9 +426,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
elif self.config.use_flow_sigmas:
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
sigmas = 1.0 - alphas
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
if sigmas is None:
sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1]
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas)
if self.config.shift_terminal:
sigmas = self.stretch_shift_to_terminal(sigmas)
eps = 1e-6
if np.fabs(sigmas[0] - 1) < eps:
# to avoid inf torch.log(alpha_si) in multistep_uni_p_bh_update during first/second update
sigmas[0] -= eps
timesteps = (sigmas * self.config.num_train_timesteps).copy()
if self.config.final_sigmas_type == "sigma_min":
sigma_last = sigmas[-1]
@@ -417,6 +449,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
)
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
else:
if sigmas is None:
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
if self.config.final_sigmas_type == "sigma_min":
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -446,6 +480,43 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
if self.config.time_shift_type == "exponential":
return self._time_shift_exponential(mu, sigma, t)
elif self.config.time_shift_type == "linear":
return self._time_shift_linear(mu, sigma, t)
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.stretch_shift_to_terminal
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
r"""
Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
value.
Reference:
https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
Args:
t (`torch.Tensor`):
A tensor of timesteps to be stretched and shifted.
Returns:
`torch.Tensor`:
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
"""
one_minus_z = 1 - t
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
stretched_t = 1 - (one_minus_z / scale_factor)
return stretched_t
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
def _time_shift_exponential(self, mu, sigma, t):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
def _time_shift_linear(self, mu, sigma, t):
return mu / (mu + (1 / t - 1) ** sigma)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
"""

View File

@@ -248,6 +248,9 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=5e-1)
def test_save_load_dduf(self):
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
@is_flaky()
def test_model_cpu_offload_forward_pass(self):
super().test_inference_batch_single_identical(expected_max_diff=8e-4)

View File

@@ -191,6 +191,9 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=1e-2)
def test_save_load_dduf(self):
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
@slow
@require_torch_accelerator