mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-25 04:55:53 +08:00
Compare commits
13 Commits
modular-up
...
modular-lo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e77f1c56dc | ||
|
|
372222a4b6 | ||
|
|
1f57b175ae | ||
|
|
581a425130 | ||
|
|
e8e88ff2ce | ||
|
|
6e24cd842c | ||
|
|
981eb802c6 | ||
|
|
1eb40c6dbd | ||
|
|
bff672f47f | ||
|
|
d4f97d1921 | ||
|
|
1d32b19ad4 | ||
|
|
699297f647 | ||
|
|
7a02fadad3 |
@@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG PYTHON_VERSION=3.11
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get -y update \
|
||||
@@ -32,10 +32,12 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
# Install torch, torchvision, and torchaudio together to ensure compatibility
|
||||
RUN uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio
|
||||
torchaudio \
|
||||
--index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
|
||||
LABEL maintainer="Hugging Face"
|
||||
LABEL repository="diffusers"
|
||||
|
||||
ARG PYTHON_VERSION=3.12
|
||||
ARG PYTHON_VERSION=3.11
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
RUN apt-get -y update \
|
||||
@@ -32,10 +32,12 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
|
||||
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
|
||||
|
||||
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
|
||||
# Install torch, torchvision, and torchaudio together to ensure compatibility
|
||||
RUN uv pip install --no-cache-dir \
|
||||
torch \
|
||||
torchvision \
|
||||
torchaudio
|
||||
torchaudio \
|
||||
--index-url https://download.pytorch.org/whl/cu121
|
||||
|
||||
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import Optional, Union
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..utils import logging
|
||||
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, logging
|
||||
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||
|
||||
|
||||
@@ -220,4 +220,11 @@ class AutoModel(ConfigMixin):
|
||||
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
|
||||
|
||||
kwargs = {**load_config_kwargs, **kwargs}
|
||||
return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
|
||||
model = model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
|
||||
|
||||
load_id_kwargs = {"pretrained_model_name_or_path": pretrained_model_or_path, **kwargs}
|
||||
parts = [load_id_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
|
||||
load_id = "|".join("null" if p is None else p for p in parts)
|
||||
model._diffusers_load_id = load_id
|
||||
|
||||
return model
|
||||
|
||||
@@ -366,7 +366,12 @@ class ResnetBlock2D(nn.Module):
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor.contiguous())
|
||||
# Only use contiguous() during training to avoid DDP gradient stride mismatch warning.
|
||||
# In inference mode (eval or no_grad), skip contiguous() for better performance, especially on CPU.
|
||||
# Issue: https://github.com/huggingface/diffusers/issues/12975
|
||||
if self.training:
|
||||
input_tensor = input_tensor.contiguous()
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin, FeedForward
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
|
||||
@@ -400,12 +400,14 @@ class LongCatImageTransformer2DModel(
|
||||
PeftAdapterMixin,
|
||||
FromOriginalModelMixin,
|
||||
CacheMixin,
|
||||
AttentionMixin,
|
||||
):
|
||||
"""
|
||||
The Transformer model introduced in Longcat-Image.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_repeated_blocks = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -1552,11 +1552,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
else:
|
||||
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
|
||||
|
||||
self._blocks = blocks
|
||||
self.blocks = blocks
|
||||
self._components_manager = components_manager
|
||||
self._collection = collection
|
||||
self._component_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_components}
|
||||
self._config_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_configs}
|
||||
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
|
||||
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
|
||||
|
||||
# update component_specs and config_specs based on modular_model_index.json
|
||||
if modular_config_dict is not None:
|
||||
@@ -1603,9 +1603,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
for name, config_spec in self._config_specs.items():
|
||||
default_configs[name] = config_spec.default
|
||||
self.register_to_config(**default_configs)
|
||||
self.register_to_config(
|
||||
_blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None
|
||||
)
|
||||
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
|
||||
|
||||
@property
|
||||
def default_call_parameters(self) -> Dict[str, Any]:
|
||||
@@ -1614,7 +1612,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- Dictionary mapping input names to their default values
|
||||
"""
|
||||
params = {}
|
||||
for input_param in self._blocks.inputs:
|
||||
for input_param in self.blocks.inputs:
|
||||
params[input_param.name] = input_param.default
|
||||
return params
|
||||
|
||||
@@ -1777,15 +1775,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
Returns:
|
||||
- The docstring of the pipeline blocks
|
||||
"""
|
||||
return self._blocks.doc
|
||||
|
||||
@property
|
||||
def blocks(self) -> ModularPipelineBlocks:
|
||||
"""
|
||||
Returns:
|
||||
- A copy of the pipeline blocks
|
||||
"""
|
||||
return deepcopy(self._blocks)
|
||||
return self.blocks.doc
|
||||
|
||||
def register_components(self, **kwargs):
|
||||
"""
|
||||
@@ -2152,6 +2142,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
name
|
||||
for name in self._component_specs.keys()
|
||||
if self._component_specs[name].default_creation_method == "from_pretrained"
|
||||
and self._component_specs[name].pretrained_model_name_or_path is not None
|
||||
and getattr(self, name, None) is None
|
||||
]
|
||||
elif isinstance(names, str):
|
||||
names = [names]
|
||||
@@ -2519,7 +2511,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
|
||||
def set_progress_bar_config(self, **kwargs):
|
||||
for sub_block_name, sub_block in self._blocks.sub_blocks.items():
|
||||
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
|
||||
if hasattr(sub_block, "set_progress_bar_config"):
|
||||
sub_block.set_progress_bar_config(**kwargs)
|
||||
|
||||
@@ -2573,7 +2565,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
# Add inputs to state, using defaults if not provided in the kwargs or the state
|
||||
# if same input already in the state, will override it if provided in the kwargs
|
||||
for expected_input_param in self._blocks.inputs:
|
||||
for expected_input_param in self.blocks.inputs:
|
||||
name = expected_input_param.name
|
||||
default = expected_input_param.default
|
||||
kwargs_type = expected_input_param.kwargs_type
|
||||
@@ -2592,9 +2584,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
# Run the pipeline
|
||||
with torch.no_grad():
|
||||
try:
|
||||
_, state = self._blocks(self, state)
|
||||
_, state = self.blocks(self, state)
|
||||
except Exception:
|
||||
error_msg = f"Error in block: ({self._blocks.__class__.__name__}):\n"
|
||||
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
|
||||
@@ -15,14 +15,14 @@
|
||||
import inspect
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||
from ..loaders.single_file_utils import _is_single_file_path_or_url
|
||||
from ..utils import is_torch_available, logging
|
||||
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -185,7 +185,7 @@ class ComponentSpec:
|
||||
"""
|
||||
Return the names of all loading‐related fields (i.e. those whose field.metadata["loading"] is True).
|
||||
"""
|
||||
return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
|
||||
return DIFFUSERS_LOAD_ID_FIELDS.copy()
|
||||
|
||||
@property
|
||||
def load_id(self) -> str:
|
||||
@@ -197,7 +197,7 @@ class ComponentSpec:
|
||||
return "null"
|
||||
parts = [getattr(self, k) for k in self.loading_fields()]
|
||||
parts = ["null" if p is None else p for p in parts]
|
||||
return "|".join(p for p in parts if p)
|
||||
return "|".join(parts)
|
||||
|
||||
@classmethod
|
||||
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
|
||||
|
||||
@@ -22,6 +22,7 @@ import flax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import logging
|
||||
from .scheduling_utils_flax import (
|
||||
CommonSchedulerState,
|
||||
FlaxKarrasDiffusionSchedulers,
|
||||
@@ -32,6 +33,9 @@ from .scheduling_utils_flax import (
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class DDIMSchedulerState:
|
||||
common: CommonSchedulerState
|
||||
@@ -125,6 +129,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
prediction_type: str = "epsilon",
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
logger.warning(
|
||||
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
|
||||
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
|
||||
)
|
||||
self.dtype = dtype
|
||||
|
||||
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
|
||||
@@ -152,7 +160,10 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
def scale_model_input(
|
||||
self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
|
||||
self,
|
||||
state: DDIMSchedulerState,
|
||||
sample: jnp.ndarray,
|
||||
timestep: Optional[int] = None,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Args:
|
||||
@@ -190,7 +201,9 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
|
||||
alpha_prod_t = state.common.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = jnp.where(
|
||||
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
|
||||
prev_timestep >= 0,
|
||||
state.common.alphas_cumprod[prev_timestep],
|
||||
state.final_alpha_cumprod,
|
||||
)
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
@@ -99,7 +99,7 @@ def betas_for_alpha_bar(
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
||||
def rescale_zero_terminal_snr(betas):
|
||||
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
|
||||
|
||||
@@ -187,14 +187,14 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
clip_sample: bool = True,
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
|
||||
clip_sample_range: float = 1.0,
|
||||
timestep_spacing: str = "leading",
|
||||
timestep_spacing: Literal["leading", "trailing"] = "leading",
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -210,7 +210,15 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
self.betas = (
|
||||
torch.linspace(
|
||||
beta_start**0.5,
|
||||
beta_end**0.5,
|
||||
num_train_timesteps,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
** 2
|
||||
)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
@@ -256,7 +264,11 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -308,20 +320,10 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`float`):
|
||||
timestep (`int`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
eta (`float`):
|
||||
The weight of noise for added noise in diffusion step.
|
||||
use_clipped_model_output (`bool`, defaults to `False`):
|
||||
If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary
|
||||
because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no
|
||||
clipping has happened, "corrected" `model_output` would coincide with the one provided as input and
|
||||
`use_clipped_model_output` has no effect.
|
||||
variance_noise (`torch.Tensor`):
|
||||
Alternative to generating noise with `generator` by directly providing the noise for the variance
|
||||
itself. Useful for methods such as [`CycleDiffusion`].
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~schedulers.scheduling_ddim_inverse.DDIMInverseSchedulerOutput`] or
|
||||
`tuple`.
|
||||
@@ -335,7 +337,8 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
# 1. get previous step value (=t+1)
|
||||
prev_timestep = timestep
|
||||
timestep = min(
|
||||
timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
|
||||
timestep - self.config.num_train_timesteps // self.num_inference_steps,
|
||||
self.config.num_train_timesteps - 1,
|
||||
)
|
||||
|
||||
# 2. compute alphas, betas
|
||||
@@ -378,5 +381,5 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
return (prev_sample, pred_original_sample)
|
||||
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -101,7 +101,7 @@ def betas_for_alpha_bar(
|
||||
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
|
||||
def rescale_zero_terminal_snr(betas):
|
||||
def rescale_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
|
||||
|
||||
@@ -266,7 +266,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return sample
|
||||
|
||||
def _get_variance(self, timestep, prev_timestep=None):
|
||||
def _get_variance(self, timestep: int, prev_timestep: Optional[int] = None) -> torch.Tensor:
|
||||
if prev_timestep is None:
|
||||
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
|
||||
|
||||
@@ -279,7 +279,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return variance
|
||||
|
||||
def _batch_get_variance(self, t, prev_t):
|
||||
def _batch_get_variance(self, t: torch.Tensor, prev_t: torch.Tensor) -> torch.Tensor:
|
||||
alpha_prod_t = self.alphas_cumprod[t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[torch.clip(prev_t, min=0)]
|
||||
alpha_prod_t_prev[prev_t < 0] = torch.tensor(1.0)
|
||||
@@ -335,7 +335,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddim.DDIMScheduler.set_timesteps
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None) -> None:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -392,7 +392,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample: torch.Tensor,
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
generator=None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
variance_noise: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[DDIMParallelSchedulerOutput, Tuple]:
|
||||
@@ -406,11 +406,13 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample (`torch.Tensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
eta (`float`): weight of noise for added noise in diffusion step.
|
||||
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
|
||||
predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
|
||||
`self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
|
||||
coincide with the one provided as input and `use_clipped_model_output` will have not effect.
|
||||
generator: random number generator.
|
||||
use_clipped_model_output (`bool`, defaults to `False`):
|
||||
If `True`, compute "corrected" `model_output` from the clipped predicted original sample. This
|
||||
correction is necessary because the predicted original sample is clipped to [-1, 1] when
|
||||
`self.config.clip_sample` is `True`. If no clipping occurred, the "corrected" `model_output` matches
|
||||
the input and `use_clipped_model_output` has no effect.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
Random number generator.
|
||||
variance_noise (`torch.Tensor`): instead of generating noise for the variance using `generator`, we
|
||||
can directly provide the noise for the variance itself. This is useful for methods such as
|
||||
CycleDiffusion. (https://huggingface.co/papers/2210.05559)
|
||||
@@ -496,7 +498,10 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if variance_noise is None:
|
||||
variance_noise = randn_tensor(
|
||||
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
|
||||
model_output.shape,
|
||||
generator=generator,
|
||||
device=model_output.device,
|
||||
dtype=model_output.dtype,
|
||||
)
|
||||
variance = std_dev_t * variance_noise
|
||||
|
||||
@@ -513,7 +518,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
def batch_step_no_noise(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timesteps: List[int],
|
||||
timesteps: torch.Tensor,
|
||||
sample: torch.Tensor,
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
@@ -528,7 +533,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
Args:
|
||||
model_output (`torch.Tensor`): direct output from learned diffusion model.
|
||||
timesteps (`List[int]`):
|
||||
timesteps (`torch.Tensor`):
|
||||
current discrete timesteps in the diffusion chain. This is now a list of integers.
|
||||
sample (`torch.Tensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
@@ -696,5 +701,5 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
|
||||
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||
return velocity
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -22,6 +22,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import logging
|
||||
from .scheduling_utils_flax import (
|
||||
CommonSchedulerState,
|
||||
FlaxKarrasDiffusionSchedulers,
|
||||
@@ -32,6 +33,9 @@ from .scheduling_utils_flax import (
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@flax.struct.dataclass
|
||||
class DDPMSchedulerState:
|
||||
common: CommonSchedulerState
|
||||
@@ -42,7 +46,12 @@ class DDPMSchedulerState:
|
||||
num_inference_steps: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray):
|
||||
def create(
|
||||
cls,
|
||||
common: CommonSchedulerState,
|
||||
init_noise_sigma: jnp.ndarray,
|
||||
timesteps: jnp.ndarray,
|
||||
):
|
||||
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps)
|
||||
|
||||
|
||||
@@ -105,6 +114,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
prediction_type: str = "epsilon",
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
):
|
||||
logger.warning(
|
||||
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
|
||||
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
|
||||
)
|
||||
self.dtype = dtype
|
||||
|
||||
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
|
||||
@@ -123,7 +136,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
def scale_model_input(
|
||||
self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
|
||||
self,
|
||||
state: DDPMSchedulerState,
|
||||
sample: jnp.ndarray,
|
||||
timestep: Optional[int] = None,
|
||||
) -> jnp.ndarray:
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -226,6 +226,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
time_shift_type: Literal["exponential"] = "exponential",
|
||||
sigma_min: Optional[float] = None,
|
||||
sigma_max: Optional[float] = None,
|
||||
shift_terminal: Optional[float] = None,
|
||||
) -> None:
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
@@ -245,6 +246,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
|
||||
if shift_terminal is not None and not use_flow_sigmas:
|
||||
raise ValueError("`shift_terminal` is only supported when `use_flow_sigmas=True`.")
|
||||
|
||||
if rescale_betas_zero_snr:
|
||||
self.betas = rescale_zero_terminal_snr(self.betas)
|
||||
@@ -313,8 +316,12 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(
|
||||
self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None
|
||||
) -> None:
|
||||
self,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
mu: Optional[float] = None,
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -323,13 +330,24 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
|
||||
automatically.
|
||||
mu (`float`, *optional*):
|
||||
Optional mu parameter for dynamic shifting when using exponential time shift type.
|
||||
"""
|
||||
if self.config.use_dynamic_shifting and mu is None:
|
||||
raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
|
||||
|
||||
if sigmas is not None:
|
||||
if not self.config.use_flow_sigmas:
|
||||
raise ValueError(
|
||||
"Passing `sigmas` is only supported when `use_flow_sigmas=True`. "
|
||||
"Please set `use_flow_sigmas=True` during scheduler initialization."
|
||||
)
|
||||
num_inference_steps = len(sigmas)
|
||||
|
||||
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
|
||||
if mu is not None:
|
||||
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
|
||||
self.config.flow_shift = np.exp(mu)
|
||||
if self.config.timestep_spacing == "linspace":
|
||||
timesteps = (
|
||||
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
|
||||
@@ -354,8 +372,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
|
||||
)
|
||||
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
if self.config.use_karras_sigmas:
|
||||
if sigmas is None:
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
@@ -375,6 +394,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
elif self.config.use_exponential_sigmas:
|
||||
if sigmas is None:
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
@@ -389,6 +410,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
elif self.config.use_beta_sigmas:
|
||||
if sigmas is None:
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
log_sigmas = np.log(sigmas)
|
||||
sigmas = np.flip(sigmas).copy()
|
||||
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
|
||||
@@ -403,9 +426,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
elif self.config.use_flow_sigmas:
|
||||
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
|
||||
sigmas = 1.0 - alphas
|
||||
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
|
||||
if sigmas is None:
|
||||
sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1]
|
||||
if self.config.use_dynamic_shifting:
|
||||
sigmas = self.time_shift(mu, 1.0, sigmas)
|
||||
else:
|
||||
sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas)
|
||||
if self.config.shift_terminal:
|
||||
sigmas = self.stretch_shift_to_terminal(sigmas)
|
||||
eps = 1e-6
|
||||
if np.fabs(sigmas[0] - 1) < eps:
|
||||
# to avoid inf torch.log(alpha_si) in multistep_uni_p_bh_update during first/second update
|
||||
sigmas[0] -= eps
|
||||
timesteps = (sigmas * self.config.num_train_timesteps).copy()
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
sigma_last = sigmas[-1]
|
||||
@@ -417,6 +449,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
|
||||
else:
|
||||
if sigmas is None:
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
|
||||
if self.config.final_sigmas_type == "sigma_min":
|
||||
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
|
||||
@@ -446,6 +480,43 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._begin_index = None
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
|
||||
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
|
||||
if self.config.time_shift_type == "exponential":
|
||||
return self._time_shift_exponential(mu, sigma, t)
|
||||
elif self.config.time_shift_type == "linear":
|
||||
return self._time_shift_linear(mu, sigma, t)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.stretch_shift_to_terminal
|
||||
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
|
||||
value.
|
||||
|
||||
Reference:
|
||||
https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
|
||||
|
||||
Args:
|
||||
t (`torch.Tensor`):
|
||||
A tensor of timesteps to be stretched and shifted.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
|
||||
"""
|
||||
one_minus_z = 1 - t
|
||||
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
|
||||
stretched_t = 1 - (one_minus_z / scale_factor)
|
||||
return stretched_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
|
||||
def _time_shift_exponential(self, mu, sigma, t):
|
||||
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
|
||||
def _time_shift_linear(self, mu, sigma, t):
|
||||
return mu / (mu + (1 / t - 1) ** sigma)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
@@ -23,6 +23,7 @@ from .constants import (
|
||||
DEFAULT_HF_PARALLEL_LOADING_WORKERS,
|
||||
DEPRECATED_REVISION_ARGS,
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME,
|
||||
DIFFUSERS_LOAD_ID_FIELDS,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
GGUF_FILE_EXTENSION,
|
||||
HF_ENABLE_PARALLEL_LOADING,
|
||||
|
||||
@@ -73,3 +73,11 @@ DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoint
|
||||
ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
ENCODE_ENDPOINT_FLUX = "https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/"
|
||||
|
||||
|
||||
DIFFUSERS_LOAD_ID_FIELDS = [
|
||||
"pretrained_model_name_or_path",
|
||||
"subfolder",
|
||||
"variant",
|
||||
"revision",
|
||||
]
|
||||
|
||||
@@ -248,6 +248,9 @@ class KandinskyV22InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCas
|
||||
def test_float16_inference(self):
|
||||
super().test_float16_inference(expected_max_diff=5e-1)
|
||||
|
||||
def test_save_load_dduf(self):
|
||||
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
|
||||
|
||||
@is_flaky()
|
||||
def test_model_cpu_offload_forward_pass(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=8e-4)
|
||||
|
||||
@@ -191,6 +191,9 @@ class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
|
||||
def test_inference_batch_single_identical(self):
|
||||
super().test_inference_batch_single_identical(expected_max_diff=1e-2)
|
||||
|
||||
def test_save_load_dduf(self):
|
||||
super().test_save_load_dduf(atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_accelerator
|
||||
|
||||
Reference in New Issue
Block a user