mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-25 04:55:53 +08:00
Compare commits
9 Commits
modular-te
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e8e88ff2ce | ||
|
|
6e24cd842c | ||
|
|
981eb802c6 | ||
|
|
1eb40c6dbd | ||
|
|
bff672f47f | ||
|
|
d4f97d1921 | ||
|
|
1d32b19ad4 | ||
|
|
699297f647 | ||
|
|
7a02fadad3 |
@@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG PYTHON_VERSION=3.11
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get -y update \
|
||||
@@ -32,10 +32,12 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
# Install torch, torchvision, and torchaudio together to ensure compatibility
|
||||
RUN uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio
|
||||
torchaudio \
|
||||
--index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG PYTHON_VERSION=3.11
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get -y update \
|
||||
@@ -32,10 +32,12 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
# Install torch, torchvision, and torchaudio together to ensure compatibility
|
||||
RUN uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio
|
||||
torchaudio \
|
||||
--index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
|
||||
|
||||
|
||||
@@ -366,7 +366,12 @@ class ResnetBlock2D(nn.Module):
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor.contiguous())
|
||||
# Only use contiguous() during training to avoid DDP gradient stride mismatch warning.
|
||||
# In inference mode (eval or no_grad), skip contiguous() for better performance, especially on CPU.
|
||||
# Issue: https://github.com/huggingface/diffusers/issues/12975
|
||||
if self.training:
|
||||
input_tensor = input_tensor.contiguous()
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin, FeedForward
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
|
||||
@@ -400,12 +400,14 @@ class LongCatImageTransformer2DModel(
|
||||
PeftAdapterMixin,
|
||||
FromOriginalModelMixin,
|
||||
CacheMixin,
|
||||
AttentionMixin,
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in Longcat-Image.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_repeated_blocks = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -22,6 +22,7 @@ import flax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import logging
|
||||
from .scheduling_utils_flax import (
|
||||
CommonSchedulerState,
|
||||
FlaxKarrasDiffusionSchedulers,
|
||||
@@ -32,6 +33,9 @@ from .scheduling_utils_flax import (
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class DDIMSchedulerState:
|
||||
common: CommonSchedulerState
|
||||
@@ -125,6 +129,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
prediction_type: str = "epsilon",
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
logger.warning(
|
||||
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
|
||||
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
|
||||
)
|
||||
self.dtype = dtype
|
||||
|
||||
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
|
||||
@@ -152,7 +160,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
def scale_model_input(
|
||||
self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
|
||||
self,
|
||||
state: DDIMSchedulerState,
|
||||
sample: jnp.ndarray,
|
||||
timestep: Optional[int] = None,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Args:
|
||||
@@ -190,7 +201,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
|
||||
alpha_prod_t = state.common.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = jnp.where(
|
||||
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
|
||||
prev_timestep >= 0,
|
||||
state.common.alphas_cumprod[prev_timestep],
|
||||
state.final_alpha_cumprod,
|
||||
)
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
@@ -99,7 +99,7 @@ def betas_for_alpha_bar(
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
||||
def rescale_zero_terminal_snr(betas):
|
||||
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
|
||||
|
||||
@@ -187,14 +187,14 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
clip_sample: bool = True,
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
|
||||
clip_sample_range: float = 1.0,
|
||||
timestep_spacing: str = "leading",
|
||||
timestep_spacing: Literal["leading", "trailing"] = "leading",
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -210,7 +210,15 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
self.betas = (
|
||||
torch.linspace(
|
||||
beta_start**0.5,
|
||||
beta_end**0.5,
|
||||
num_train_timesteps,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
** 2
|
||||
)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
@@ -256,7 +264,11 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -308,20 +320,10 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`float`):
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
eta (`float`):
|
||||
The weight of noise for added noise in diffusion step.
|
||||
use_clipped_model_output (`bool`, defaults to `False`):
|
||||
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
|
||||
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
|
||||
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
|
||||
`use_clipped_model_output` has no effect.
|
||||
variance_noise (`torch.Tensor`):
|
||||
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
||||
itself. Useful for methods such as [`CycleDiffusion`].
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~schedulers.scheduling_ddim_inverse.DDIMInverseSchedulerOutput`] or
|
||||
`tuple`.
|
||||
@@ -335,7 +337,8 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
# 1. get previous step value (=t+1)
|
||||
prev_timestep = timestep
|
||||
timestep = min(
|
||||
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
|
||||
timestep - self.config.num_train_timesteps // self.num_inference_steps,
|
||||
self.config.num_train_timesteps - 1,
|
||||
)
|
||||
|
||||
# 2. compute alphas, betas
|
||||
@@ -378,5 +381,5 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
return (prev_sample, pred_original_sample)
|
||||
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -101,7 +101,7 @@ def betas_for_alpha_bar(
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
||||
def rescale_zero_terminal_snr(betas):
|
||||
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
|
||||
|
||||
@@ -266,7 +266,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
def _get_variance(self, timestep, prev_timestep=None):
|
||||
def _get_variance(self, timestep: int, prev_timestep: Optional[int] = None) -> torch.Tensor:
|
||||
if prev_timestep is None:
|
||||
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
|
||||
|
||||
@@ -279,7 +279,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return variance
|
||||
|
||||
def _batch_get_variance(self, t, prev_t):
|
||||
def _batch_get_variance(self, t: torch.Tensor, prev_t: torch.Tensor) -> torch.Tensor:
|
||||
alpha_prod_t = self.alphas_cumprod[t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
|
||||
alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)
|
||||
@@ -335,7 +335,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -392,7 +392,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample: torch.Tensor,
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
generator=None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
variance_noise: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DDIMParallelSchedulerOutput, Tuple]:
|
||||
@@ -406,11 +406,13 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample (`torch.Tensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
eta (`float`): weight of noise for added noise in diffusion step.
|
||||
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
|
||||
predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
|
||||
`self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
|
||||
coincide with the one provided as input and `use_clipped_model_output` will have not effect.
|
||||
generator: random number generator.
|
||||
use_clipped_model_output (`bool`, defaults to `False`):
|
||||
If `True`, compute "corrected" `model_output` from the clipped predicted original sample. This
|
||||
correction is necessary because the predicted original sample is clipped to [-1, 1] when
|
||||
`self.config.clip_sample` is `True`. If no clipping occurred, the "corrected" `model_output` matches
|
||||
the input and `use_clipped_model_output` has no effect.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
Random number generator.
|
||||
variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we
|
||||
can directly provide the noise for the variance itself. This is useful for methods such as
|
||||
CycleDiffusion. (https://huggingface.co/papers/2210.05559)
|
||||
@@ -496,7 +498,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if variance_noise is None:
|
||||
variance_noise = randn_tensor(
|
||||
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
|
||||
model_output.shape,
|
||||
generator=generator,
|
||||
device=model_output.device,
|
||||
dtype=model_output.dtype,
|
||||
)
|
||||
variance = std_dev_t * variance_noise
|
||||
|
||||
@@ -513,7 +518,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
def batch_step_no_noise(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timesteps: List[int],
|
||||
timesteps: torch.Tensor,
|
||||
sample: torch.Tensor,
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
@@ -528,7 +533,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
Args:
|
||||
model_output (`torch.Tensor`): direct output from learned diffusion model.
|
||||
timesteps (`List[int]`):
|
||||
timesteps (`torch.Tensor`):
|
||||
current discrete timesteps in the diffusion chain. This is now a list of integers.
|
||||
sample (`torch.Tensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
@@ -696,5 +701,5 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||
return velocity
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -22,6 +22,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import logging
|
||||
from .scheduling_utils_flax import (
|
||||
CommonSchedulerState,
|
||||
FlaxKarrasDiffusionSchedulers,
|
||||
@@ -32,6 +33,9 @@ from .scheduling_utils_flax import (
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class DDPMSchedulerState:
|
||||
common: CommonSchedulerState
|
||||
@@ -42,7 +46,12 @@ class DDPMSchedulerState:
|
||||
num_inference_steps: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray):
|
||||
def create(
|
||||
cls,
|
||||
common: CommonSchedulerState,
|
||||
init_noise_sigma: jnp.ndarray,
|
||||
timesteps: jnp.ndarray,
|
||||
):
|
||||
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps)
|
||||
|
||||
|
||||
@@ -105,6 +114,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
prediction_type: str = "epsilon",
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
logger.warning(
|
||||
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
|
||||
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
|
||||
)
|
||||
self.dtype = dtype
|
||||
|
||||
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
|
||||
@@ -123,7 +136,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
def scale_model_input(
|
||||
self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
|
||||
self,
|
||||
state: DDPMSchedulerState,
|
||||
sample: jnp.ndarray,
|
||||
timestep: Optional[int] = None,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -226,6 +226,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
time_shift_type: Literal["exponential"] = "exponential",
|
||||
sigma_min: Optional[float] = None,
|
||||
sigma_max: Optional[float] = None,
|
||||
shift_terminal: Optional[float] = None,
|
||||
) -> None:
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
@@ -245,6 +246,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
||||
if shift_terminal is not None and not use_flow_sigmas:
|
||||
raise ValueError("`shift_terminal` is only supported when `use_flow_sigmas=True`.")
|
||||
|
||||
if rescale_betas_zero_snr:
|
||||
self.betas = rescale_zero_terminal_snr(self.betas)
|
||||
@@ -313,8 +316,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(
|
||||
self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None
|
||||
) -> None:
|
||||
self,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
mu: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -323,13 +330,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
|
||||
automatically.
|
||||
mu (`float`, *optional*):
|
||||
Optional mu parameter for dynamic shifting when using exponential time shift type.
|
||||
"""
|
||||
if self.config.use_dynamic_shifting and mu is None:
|
||||
raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
|
||||
|
||||
if sigmas is not None:
|
||||
if not self.config.use_flow_sigmas:
|
||||
raise ValueError(
|
||||
"Passing `sigmas` is only supported when `use_flow_sigmas=True`. "
|
||||
"Please set `use_flow_sigmas=True` during scheduler initialization."
|
||||
)
|
||||
num_inference_steps = len(sigmas)
|
||||
|
||||
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
|
||||
if mu is not None:
|
||||
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
|
||||
self.config.flow_shift = np.exp(mu)
|
||||
if self.config.timestep_spacing == "linspace":
|
||||
timesteps = (
|
||||
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
|
||||
@@ -354,8 +372,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
|
||||
)
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
if self.config.use_karras_sigmas:
|
||||
if sigmas is None:
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
@@ -375,6 +394,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
elif self.config.use_exponential_sigmas:
|
||||
if sigmas is None:
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
@@ -389,6 +410,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
elif self.config.use_beta_sigmas:
|
||||
if sigmas is None:
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
@@ -403,9 +426,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
elif self.config.use_flow_sigmas:
|
||||
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
|
||||
sigmas = 1.0 - alphas
|
||||
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
|
||||
if sigmas is None:
|
||||
sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1]
|
||||
if self.config.use_dynamic_shifting:
|
||||
sigmas = self.time_shift(mu, 1.0, sigmas)
|
||||
else:
|
||||
sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas)
|
||||
if self.config.shift_terminal:
|
||||
sigmas = self.stretch_shift_to_terminal(sigmas)
|
||||
eps = 1e-6
|
||||
if np.fabs(sigmas[0] - 1) < eps:
|
||||
# to avoid inf torch.log(alpha_si) in multistep_uni_p_bh_update during first/second update
|
||||
sigmas[0] -= eps
|
||||
timesteps = (sigmas * self.config.num_train_timesteps).copy()
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
sigma_last = sigmas[-1]
|
||||
@@ -417,6 +449,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
else:
|
||||
if sigmas is None:
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
@@ -446,6 +480,43 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
|
||||
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
||||
if self.config.time_shift_type == "exponential":
|
||||
return self._time_shift_exponential(mu, sigma, t)
|
||||
elif self.config.time_shift_type == "linear":
|
||||
return self._time_shift_linear(mu, sigma, t)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.stretch_shift_to_terminal
|
||||
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
|
||||
value.
|
||||
|
||||
Reference:
|
||||
https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
|
||||
|
||||
Args:
|
||||
t (`torch.Tensor`):
|
||||
A tensor of timesteps to be stretched and shifted.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
|
||||
"""
|
||||
one_minus_z = 1 - t
|
||||
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
|
||||
stretched_t = 1 - (one_minus_z / scale_factor)
|
||||
return stretched_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
|
||||
def _time_shift_exponential(self, mu, sigma, t):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
|
||||
def _time_shift_linear(self, mu, sigma, t):
|
||||
return mu / (mu + (1 / t - 1) ** sigma)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
@@ -37,14 +37,9 @@ class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxModularPipeline
|
||||
pipeline_blocks_class = FluxAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
|
||||
default_repo_id = "black-forest-labs/FLUX.1-dev"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt", "max_sequence_length"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = None # None if vae_encoder is not supported
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
@@ -68,21 +63,10 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxModularPipeline
|
||||
pipeline_blocks_class = FluxAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"
|
||||
default_repo_id = "black-forest-labs/FLUX.1-dev"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"max_sequence_length",
|
||||
]
|
||||
)
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = frozenset(["image", "height", "width"])
|
||||
|
||||
def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
|
||||
pipeline = super().get_pipeline(components_manager, torch_dtype)
|
||||
|
||||
@@ -145,13 +129,9 @@ class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = FluxKontextModularPipeline
|
||||
pipeline_blocks_class = FluxKontextAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"
|
||||
default_repo_id = "black-forest-labs/FLUX.1-kontext-dev"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
text_encoder_block_params = frozenset(["prompt", "max_sequence_length"])
|
||||
decode_block_params = frozenset(["latents"])
|
||||
vae_encoder_block_params = frozenset(["image", "height", "width"])
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
|
||||
@@ -32,14 +32,9 @@ class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = Flux2ModularPipeline
|
||||
pipeline_blocks_class = Flux2AutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
|
||||
default_repo_id = "black-forest-labs/FLUX.2-dev"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = None
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
@@ -68,10 +63,6 @@ class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||
batch_params = frozenset(["prompt", "image"])
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = frozenset(["image", "height", "width"])
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
|
||||
@@ -34,16 +34,10 @@ class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuider
|
||||
pipeline_class = QwenImageModularPipeline
|
||||
pipeline_blocks_class = QwenImageAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-modular"
|
||||
default_repo_id = "Qwen/Qwen-Image"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt", "negative_prompt", "max_sequence_length"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = None # None if vae_encoder is not supported
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
generator = self.get_generator()
|
||||
inputs = {
|
||||
@@ -66,16 +60,10 @@ class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGu
|
||||
pipeline_class = QwenImageEditModularPipeline
|
||||
pipeline_blocks_class = QwenImageEditAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-modular"
|
||||
default_repo_id = "Qwen/Qwen-Image-Edit"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
|
||||
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt", "negative_prompt", "max_sequence_length"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = frozenset(["image", "height", "width"])
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
generator = self.get_generator()
|
||||
inputs = {
|
||||
@@ -98,7 +86,6 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul
|
||||
pipeline_class = QwenImageEditPlusModularPipeline
|
||||
pipeline_blocks_class = QwenImageEditPlusAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
|
||||
default_repo_id = "Qwen/Qwen-Image-Edit-2509"
|
||||
|
||||
# No `mask_image` yet.
|
||||
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])
|
||||
|
||||
@@ -279,8 +279,6 @@ class TestSDXLModularPipelineFast(
|
||||
pipeline_class = StableDiffusionXLModularPipeline
|
||||
pipeline_blocks_class = StableDiffusionXLAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
|
||||
default_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
@@ -293,11 +291,6 @@ class TestSDXLModularPipelineFast(
|
||||
batch_params = frozenset(["prompt", "negative_prompt"])
|
||||
expected_image_output_shape = (1, 3, 64, 64)
|
||||
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = None # None if vae_encoder is not supported
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
@@ -346,11 +339,6 @@ class TestSDXLImg2ImgModularPipelineFast(
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "image"])
|
||||
expected_image_output_shape = (1, 3, 64, 64)
|
||||
|
||||
# should choose from the dict returned by `get_dummy_inputs`
|
||||
text_encoder_block_params = frozenset(["prompt"])
|
||||
decode_block_params = frozenset(["output_type"])
|
||||
vae_encoder_block_params = frozenset(["image"])
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
|
||||
@@ -48,12 +48,6 @@ class ModularPipelineTesterMixin:
|
||||
"You need to set the attribute `pretrained_model_name_or_path` in the child test class. See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def default_repo_id(self) -> str:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `default_repo_id` in the child test class. See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
@property
|
||||
def pipeline_blocks_class(self) -> Union[Callable, ModularPipelineBlocks]:
|
||||
raise NotImplementedError(
|
||||
@@ -96,30 +90,6 @@ class ModularPipelineTesterMixin:
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def text_encoder_block_params(self) -> frozenset:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `text_encoder_block_params` in the child test class. "
|
||||
"`text_encoder_block_params` are the parameters required to be passed to the text encoder block. "
|
||||
" if should be a subset of the parameters returned by `get_dummy_inputs`"
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def decode_block_params(self) -> frozenset:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `decode_block_params` in the child test class. "
|
||||
"`decode_block_params` are the parameters required to be passed to the decode block. "
|
||||
" if should be a subset of the parameters returned by `get_dummy_inputs`"
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def vae_encoder_block_params(self) -> frozenset:
|
||||
raise NotImplementedError(
|
||||
"You need to set the attribute `vae_encoder_block_params` in the child test class. "
|
||||
"`vae_encoder_block_params` are the parameters required to be passed to the vae encoder block. "
|
||||
" if should be a subset of the parameters returned by `get_dummy_inputs`"
|
||||
"See existing pipeline tests for reference."
|
||||
)
|
||||
|
||||
def setup_method(self):
|
||||
# clean up the VRAM before each test
|
||||
torch.compiler.reset()
|
||||
@@ -154,96 +124,6 @@ class ModularPipelineTesterMixin:
|
||||
_check_for_parameters(self.params, input_parameters, "input")
|
||||
_check_for_parameters(self.optional_params, optional_parameters, "optional")
|
||||
|
||||
def test_loading_from_default_repo(self):
|
||||
if self.default_repo_id is None:
|
||||
return
|
||||
|
||||
try:
|
||||
pipe = ModularPipeline.from_pretrained(self.default_repo_id)
|
||||
assert pipe.blocks.__class__ == self.pipeline_blocks_class
|
||||
except Exception as e:
|
||||
assert False, f"Failed to load pipeline from default repo: {e}"
|
||||
|
||||
def test_modular_inference(self):
|
||||
# run the pipeline to get the base output for comparison
|
||||
pipe = self.get_pipeline()
|
||||
pipe.to(torch_device, torch.float32)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
standard_output = pipe(**inputs, output="images")
|
||||
|
||||
# create text, denoise, decoder (and optional vae encoder) nodes
|
||||
blocks = self.pipeline_blocks_class()
|
||||
|
||||
assert "text_encoder" in blocks.sub_blocks, "`text_encoder` block is not present in the pipeline"
|
||||
assert "denoise" in blocks.sub_blocks, "`denoise` block is not present in the pipeline"
|
||||
assert "decode" in blocks.sub_blocks, "`decode` block is not present in the pipeline"
|
||||
if self.vae_encoder_block_params is not None:
|
||||
assert "vae_encoder" in blocks.sub_blocks, "`vae_encoder` block is not present in the pipeline"
|
||||
|
||||
# manually set the components in the sub_pipe
|
||||
# a hack to workaround the fact the default pipeline properties are often incorrect for testing cases,
|
||||
# #e.g. vae_scale_factor is ususally not 8 because vae is configured to be smaller for testing
|
||||
def manually_set_all_components(pipe: ModularPipeline, sub_pipe: ModularPipeline):
|
||||
for n, comp in pipe.components.items():
|
||||
if not hasattr(sub_pipe, n):
|
||||
setattr(sub_pipe, n, comp)
|
||||
|
||||
text_node = blocks.sub_blocks["text_encoder"].init_pipeline(self.pretrained_model_name_or_path)
|
||||
text_node.load_components(torch_dtype=torch.float32)
|
||||
text_node.to(torch_device)
|
||||
manually_set_all_components(pipe, text_node)
|
||||
|
||||
denoise_node = blocks.sub_blocks["denoise"].init_pipeline(self.pretrained_model_name_or_path)
|
||||
denoise_node.load_components(torch_dtype=torch.float32)
|
||||
denoise_node.to(torch_device)
|
||||
manually_set_all_components(pipe, denoise_node)
|
||||
|
||||
decoder_node = blocks.sub_blocks["decode"].init_pipeline(self.pretrained_model_name_or_path)
|
||||
decoder_node.load_components(torch_dtype=torch.float32)
|
||||
decoder_node.to(torch_device)
|
||||
manually_set_all_components(pipe, decoder_node)
|
||||
|
||||
if self.vae_encoder_block_params is not None:
|
||||
vae_encoder_node = blocks.sub_blocks["vae_encoder"].init_pipeline(self.pretrained_model_name_or_path)
|
||||
vae_encoder_node.load_components(torch_dtype=torch.float32)
|
||||
vae_encoder_node.to(torch_device)
|
||||
manually_set_all_components(pipe, vae_encoder_node)
|
||||
else:
|
||||
vae_encoder_node = None
|
||||
|
||||
# prepare inputs for each node
|
||||
inputs = self.get_dummy_inputs()
|
||||
|
||||
def get_block_inputs(inputs: dict, block_params: frozenset) -> tuple[dict, dict]:
|
||||
block_inputs = {}
|
||||
for name in block_params:
|
||||
if name in inputs:
|
||||
block_inputs[name] = inputs.pop(name)
|
||||
return block_inputs, inputs
|
||||
|
||||
text_inputs, inputs = get_block_inputs(inputs, self.text_encoder_block_params)
|
||||
decoder_inputs, inputs = get_block_inputs(inputs, self.decode_block_params)
|
||||
if vae_encoder_node is not None:
|
||||
vae_encoder_inputs, inputs = get_block_inputs(inputs, self.vae_encoder_block_params)
|
||||
|
||||
# this is also to make sure pipelines mark text outputs as denoiser_input_fields
|
||||
text_output = text_node(**text_inputs).get_by_kwargs("denoiser_input_fields")
|
||||
if vae_encoder_node is not None:
|
||||
vae_encoder_output = vae_encoder_node(**vae_encoder_inputs).values
|
||||
denoise_inputs = {**text_output, **vae_encoder_output, **inputs}
|
||||
else:
|
||||
denoise_inputs = {**text_output, **inputs}
|
||||
|
||||
# denoise node output should be "latents"
|
||||
latents = denoise_node(**denoise_inputs).latents
|
||||
# denoder node input should be "latents" and output should be "images"
|
||||
modular_output = decoder_node(**decoder_inputs, latents=latents).images
|
||||
|
||||
assert modular_output.shape == standard_output.shape, (
|
||||
f"Modular output should have same shape as standard output {standard_output.shape}, but got {modular_output.shape}"
|
||||
)
|
||||
|
||||
def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
|
||||
pipe = self.get_pipeline().to(torch_device)
|
||||
|
||||
|
||||
@@ -248,6 +248,9 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference(expected_max_diff=5e-1)
|
||||
|
||||
def test_save_load_dduf(self):
|
||||
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
|
||||
|
||||
@is_flaky()
|
||||
def test_model_cpu_offload_forward_pass(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=8e-4)
|
||||
|
||||
@@ -191,6 +191,9 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=1e-2)
|
||||
|
||||
def test_save_load_dduf(self):
|
||||
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
|
||||
Reference in New Issue
Block a user