Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-18 02:14:43 +08:00)

Compare commits: pr-tests-f...torchao-of (5 commits)
Commits:
- 03584a8174
- 6f1042e36c
- d5da453de5
- 15370f8412
- a96b145304
@@ -88,7 +88,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolution](https://huggingface.co/papers/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao) |
| Flux Kontext multiple images | A modified version of the `FluxKontextPipeline` that supports calling Flux Kontext with multiple reference images. | [Flux Kontext multiple input Pipeline](#flux-kontext-multiple-images) | - | [Net-Mist](https://github.com/Net-Mist) |
| Flux Fill ControlNet Pipeline | A modified version of the `FluxFillPipeline` and `FluxControlNetInpaintPipeline` that supports ControlNet with the Flux Fill model. | [Flux Fill ControlNet Pipeline](#flux-fill-controlnet-pipeline) | - | [pratim4dasude](https://github.com/pratim4dasude) |

To load a custom pipeline, pass the `custom_pipeline` argument to `DiffusionPipeline`, pointing to one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines; we will merge them quickly.
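For example, a minimal sketch of that loading path (the base checkpoint and pipeline name here are placeholders; any file name from `examples/community`, without the `.py` extension, works):

```python
import torch
from diffusers import DiffusionPipeline

# `custom_pipeline` accepts the name of a community pipeline file (without ".py")
# or a local path to such a file.
pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",  # placeholder base checkpoint
    custom_pipeline="lpw_stable_diffusion",  # placeholder community pipeline
    torch_dtype=torch.float16,
)
```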
@@ -5527,3 +5527,106 @@ images = pipe(
).images
images[0].save("pizzeria.png")
```

# Flux Fill ControlNet Pipeline

This implementation of Flux Fill + ControlNet Inpaint combines the fill-style masked editing of FLUX.1-Fill-dev with full ControlNet conditioning. The base image is processed through the Fill model while the ControlNet receives the corresponding conditioning input (depth, canny, pose, etc.), and both outputs are fused during denoising to guide structure and composition.

While FLUX.1-Fill-dev is designed for mask-based edits, it was not originally trained to operate jointly with ControlNet. In practice, this combined setup works well for structured inpainting tasks, though results may vary depending on the conditioning strength and the alignment between the mask and the control input.

## Example Usage
```python
import torch
from diffusers import (
    FluxControlNetModel,
    FluxPriorReduxPipeline,
)
from diffusers.utils import load_image

# Community pipeline added by this PR; the module name matches the file
# examples/community/pipline_flux_fill_controlnet_Inpaint.py (note the spelling).
from pipline_flux_fill_controlnet_Inpaint import FluxControlNetFillInpaintPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

# Models
base_model = "black-forest-labs/FLUX.1-Fill-dev"
controlnet_model = "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0"
prior_model = "black-forest-labs/FLUX.1-Redux-dev"

# Load ControlNet
controlnet = FluxControlNetModel.from_pretrained(
    controlnet_model,
    torch_dtype=dtype,
)

# Load Fill + ControlNet pipeline
fill_pipe = FluxControlNetFillInpaintPipeline.from_pretrained(
    base_model,
    controlnet=controlnet,
    torch_dtype=dtype,
).to(device)

# Optional: FP8 layerwise casting to reduce memory
# fill_pipe.transformer.enable_layerwise_casting(
#     storage_dtype=torch.float8_e4m3fn,
#     compute_dtype=torch.bfloat16,
# )

# Optional: Prior Redux for image-based conditioning
# pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
#     prior_model,
#     torch_dtype=dtype,
# ).to(device)

# 1. Optional prior conditioning (define cloth_image and cloth_prompt before enabling)
# prior_out = pipe_prior_redux(
#     image=cloth_image,
#     prompt=cloth_prompt,
# )

# 2. Fill inpaint with ControlNet
# Union control modes: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6)
img = load_image("imgs/background.jpg")
mask = load_image("imgs/mask.png")
control_image_depth = load_image("imgs/dog_depth_2.png")

result = fill_pipe(
    prompt="a dog on a bench",
    image=img,
    mask_image=mask,
    control_image=control_image_depth,
    control_mode=[2],  # 2 = depth in the Union mode list above
    control_guidance_start=0.0,
    control_guidance_end=0.8,
    controlnet_conditioning_scale=0.9,
    height=1024,
    width=1024,
    strength=1.0,
    guidance_scale=50.0,
    num_inference_steps=60,
    max_sequence_length=512,
    # **prior_out,  # enable together with the Prior Redux block above
)

from datetime import datetime

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
result.images[0].save(f"flux_fill_controlnet_inpaint_depth_{timestamp}.jpg")
```
examples/community/pipline_flux_fill_controlnet_Inpaint.py (new file, 1319 lines)
File diff suppressed because it is too large.
@@ -488,9 +488,20 @@ class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin):
         t = t.reshape(sigma.shape)
         return t
 
-    # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
+    # Copied from diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler._convert_to_karras
     def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
-        """Constructs the noise schedule of Karras et al. (2022)."""
+        """
+        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+        Models](https://huggingface.co/papers/2206.00364).
+
+        Args:
+            in_sigmas (`torch.Tensor`):
+                The input sigma values to be converted.
+
+        Returns:
+            `torch.Tensor`:
+                The converted sigma values following the Karras noise schedule.
+        """
 
         sigma_min: float = in_sigmas[-1].item()
         sigma_max: float = in_sigmas[0].item()
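For orientation, the conversion the new docstring describes is the rho-space interpolation from the Karras et al. paper. A minimal standalone sketch, assuming the `rho=7.0` default used by the diffusers implementations:

```python
import numpy as np

def convert_to_karras(in_sigmas: np.ndarray, num_steps: int, rho: float = 7.0) -> np.ndarray:
    sigma_min, sigma_max = in_sigmas[-1], in_sigmas[0]
    ramp = np.linspace(0, 1, num_steps)
    # interpolate linearly in sigma^(1/rho) space, then raise back to the rho power
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
```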
@@ -99,15 +99,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     methods the library implements for all schedulers such as loading and saving.
 
     Args:
-        num_train_timesteps (`int`, defaults to 1000):
+        num_train_timesteps (`int`, defaults to `1000`):
             The number of diffusion steps to train the model.
-        beta_start (`float`, defaults to 0.0001):
+        beta_start (`float`, defaults to `0.0001`):
             The starting `beta` value of inference.
-        beta_end (`float`, defaults to 0.02):
+        beta_end (`float`, defaults to `0.02`):
             The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
-            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-            `linear` or `scaled_linear`.
+        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
         trained_betas (`np.ndarray`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
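For orientation, a minimal sketch of how the first two schedules named above are typically constructed (`squaredcos_cap_v2` uses a squared-cosine formula instead):

```python
import torch

def make_betas(beta_schedule: str, beta_start: float, beta_end: float, num_train_timesteps: int) -> torch.Tensor:
    if beta_schedule == "linear":
        return torch.linspace(beta_start, beta_end, num_train_timesteps)
    if beta_schedule == "scaled_linear":
        # linear in sqrt(beta) space, squared back afterwards
        return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps) ** 2
    raise ValueError(f"unsupported beta_schedule: {beta_schedule}")
```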
@@ -118,14 +117,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         use_beta_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
             Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
+        prediction_type (`"epsilon"`, `"sample"`, or `"v_prediction"`, defaults to `"epsilon"`):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
             `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
             Video](https://imagen.research.google/video/paper.pdf) paper).
-        timestep_spacing (`str`, defaults to `"linspace"`):
+        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
-        steps_offset (`int`, defaults to 0):
+        steps_offset (`int`, defaults to `0`):
             An offset added to the inference steps, as required by some model families.
     """
 
@@ -138,13 +137,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         num_train_timesteps: int = 1000,
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
-        beta_schedule: str = "linear",
+        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
-        prediction_type: str = "epsilon",
-        timestep_spacing: str = "linspace",
+        prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
+        timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
         steps_offset: int = 0,
     ):
         if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
@@ -183,7 +182,15 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     @property
-    def init_noise_sigma(self):
+    def init_noise_sigma(self) -> Union[float, torch.Tensor]:
+        """
+        The standard deviation of the initial noise distribution.
+
+        Returns:
+            `float` or `torch.Tensor`:
+                The standard deviation of the initial noise distribution, computed based on the maximum sigma value and
+                the timestep spacing configuration.
+        """
         # standard deviation of the initial noise distribution
         if self.config.timestep_spacing in ["linspace", "trailing"]:
             return self.sigmas.max()
@@ -191,21 +198,29 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             return (self.sigmas.max() ** 2 + 1) ** 0.5
 
     @property
-    def step_index(self):
+    def step_index(self) -> Optional[int]:
         """
-        The index counter for current timestep. It will increase 1 after each scheduler step.
+        The index counter for current timestep. It will increase by 1 after each scheduler step.
+
+        Returns:
+            `int` or `None`:
+                The current step index, or `None` if not initialized.
         """
         return self._step_index
 
     @property
-    def begin_index(self):
+    def begin_index(self) -> Optional[int]:
         """
         The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+        Returns:
+            `int` or `None`:
+                The begin index for the scheduler, or `None` if not set.
         """
         return self._begin_index
 
     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
-    def set_begin_index(self, begin_index: int = 0):
+    def set_begin_index(self, begin_index: int = 0) -> None:
         """
         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
 
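A hypothetical usage sketch of `set_begin_index` (the numbers are illustrative, not from this diff): pipelines that start denoising from an intermediate timestep, such as img2img with `strength < 1.0`, tell the scheduler where stepping begins.

```python
num_inference_steps, strength = 30, 0.6
scheduler.set_timesteps(num_inference_steps)
# skip the first (1 - strength) fraction of steps
t_start = max(num_inference_steps - int(num_inference_steps * strength), 0)
scheduler.set_begin_index(t_start * scheduler.order)
timesteps = scheduler.timesteps[t_start * scheduler.order :]
```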
@@ -239,14 +254,21 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self.is_scale_input_called = True
         return sample
 
-    def get_lms_coefficient(self, order, t, current_order):
+    def get_lms_coefficient(self, order: int, t: int, current_order: int) -> float:
         """
         Compute the linear multistep coefficient.
 
         Args:
-            order ():
-            t ():
-            current_order ():
+            order (`int`):
+                The order of the linear multistep method.
+            t (`int`):
+                The current timestep index.
+            current_order (`int`):
+                The current order for which to compute the coefficient.
+
+        Returns:
+            `float`:
+                The computed linear multistep coefficient.
         """
 
         def lms_derivative(tau):
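For context, a self-contained sketch of the computation the docstring now documents, mirroring the `lms_derivative` / `integrated_coeff` body that follows in the scheduler:

```python
import numpy as np
from scipy import integrate

def lms_coefficient(sigmas: np.ndarray, order: int, t: int, current_order: int) -> float:
    # Lagrange basis polynomial for `current_order` over the sigma grid
    def lms_derivative(tau: float) -> float:
        prod = 1.0
        for k in range(order):
            if current_order == k:
                continue
            prod *= (tau - sigmas[t - k]) / (sigmas[t - current_order] - sigmas[t - k])
        return prod

    # integrate the basis polynomial between adjacent sigmas
    return integrate.quad(lms_derivative, sigmas[t], sigmas[t + 1], epsrel=1e-4)[0]
```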
@@ -261,7 +283,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
         return integrated_coeff
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
@@ -367,7 +389,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self._step_index = self._begin_index
 
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
-    def _sigma_to_t(self, sigma, log_sigmas):
+    def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
         """
         Convert sigma values to corresponding timestep values through interpolation.
 
@@ -403,9 +425,19 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         t = t.reshape(sigma.shape)
         return t
 
     # copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
     def _convert_to_karras(self, in_sigmas: torch.Tensor) -> torch.Tensor:
-        """Constructs the noise schedule of Karras et al. (2022)."""
+        """
+        Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
+        Models](https://huggingface.co/papers/2206.00364).
+
+        Args:
+            in_sigmas (`torch.Tensor`):
+                The input sigma values to be converted.
+
+        Returns:
+            `torch.Tensor`:
+                The converted sigma values following the Karras noise schedule.
+        """
 
         sigma_min: float = in_sigmas[-1].item()
         sigma_max: float = in_sigmas[0].item()
@@ -629,5 +661,5 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = original_samples + noise * sigma
         return noisy_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps
 
@@ -79,15 +79,14 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
     methods the library implements for all schedulers such as loading and saving.
 
     Args:
-        num_train_timesteps (`int`, defaults to 1000):
+        num_train_timesteps (`int`, defaults to `1000`):
             The number of diffusion steps to train the model.
-        beta_start (`float`, defaults to 0.0001):
+        beta_start (`float`, defaults to `0.0001`):
             The starting `beta` value of inference.
-        beta_end (`float`, defaults to 0.02):
+        beta_end (`float`, defaults to `0.02`):
             The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
-            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
         trained_betas (`np.ndarray`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
         skip_prk_steps (`bool`, defaults to `False`):
@@ -97,14 +96,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
             Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
             there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
             otherwise it uses the alpha value at step 0.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
+        prediction_type (`"epsilon"` or `"v_prediction"`, defaults to `"epsilon"`):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process)
-            or `v_prediction` (see section 2.4 of [Imagen Video](https://imagen.research.google/video/paper.pdf)
-            paper).
-        timestep_spacing (`str`, defaults to `"leading"`):
+            or `v_prediction` (see section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
+        timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"leading"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
-        steps_offset (`int`, defaults to 0):
+        steps_offset (`int`, defaults to `0`):
             An offset added to the inference steps, as required by some model families.
     """
 
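As a quick reference for the two prediction types, a sketch under standard DDPM notation (`alpha_prod_t` is `alphas_cumprod[t]`; this helper is illustrative, not part of the diff):

```python
def v_prediction_to_eps_and_x0(v, sample, alpha_prod_t):
    # v = sqrt(alpha_prod_t) * eps - sqrt(1 - alpha_prod_t) * x0
    sqrt_a = alpha_prod_t**0.5
    sqrt_b = (1 - alpha_prod_t) ** 0.5
    x0 = sqrt_a * sample - sqrt_b * v
    eps = sqrt_a * v + sqrt_b * sample
    return eps, x0
```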
@@ -117,12 +115,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         num_train_timesteps: int = 1000,
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
-        beta_schedule: str = "linear",
+        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
         trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         skip_prk_steps: bool = False,
         set_alpha_to_one: bool = False,
-        prediction_type: str = "epsilon",
-        timestep_spacing: str = "leading",
+        prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
+        timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading",
         steps_offset: int = 0,
     ):
         if trained_betas is not None:
@@ -164,7 +162,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         self.plms_timesteps = None
         self.timesteps = None
 
-    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
 
@@ -243,7 +241,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.
 
         Returns:
@@ -276,14 +274,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
 
         Returns:
             [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                 If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                 tuple is returned where the first element is the sample tensor.
-
         """
         if self.num_inference_steps is None:
             raise ValueError(
@@ -335,14 +332,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or tuple.
 
         Returns:
             [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                 If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                 tuple is returned where the first element is the sample tensor.
-
         """
         if self.num_inference_steps is None:
             raise ValueError(
@@ -403,19 +399,27 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         """
         return sample
 
-    def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
-        # See formula (9) of PNDM paper https://huggingface.co/papers/2202.09778
-        # this function computes x_(t−δ) using the formula of (9)
-        # Note that x_t needs to be added to both sides of the equation
+    def _get_prev_sample(
+        self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Compute the previous sample x_(t-δ) from the current sample x_t using formula (9) from the [PNDM
+        paper](https://huggingface.co/papers/2202.09778).
 
-        # Notation (<variable name> -> <name in paper>
-        # alpha_prod_t -> α_t
-        # alpha_prod_t_prev -> α_(t−δ)
-        # beta_prod_t -> (1 - α_t)
-        # beta_prod_t_prev -> (1 - α_(t−δ))
-        # sample -> x_t
-        # model_output -> e_θ(x_t, t)
-        # prev_sample -> x_(t−δ)
+        Args:
+            sample (`torch.Tensor`):
+                The current sample x_t.
+            timestep (`int`):
+                The current timestep t.
+            prev_timestep (`int`):
+                The previous timestep (t-δ).
+            model_output (`torch.Tensor`):
+                The model output e_θ(x_t, t).
+
+        Returns:
+            `torch.Tensor`:
+                The previous sample x_(t-δ).
+        """
         alpha_prod_t = self.alphas_cumprod[timestep]
         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t
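To make formula (9) concrete, a minimal sketch of the update under the docstring's notation (the denominator form follows the computation later in this method, outside the hunk):

```python
def pndm_prev_sample(sample, model_output, alpha_prod_t, alpha_prod_t_prev):
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    # x_(t-δ) = sqrt(α_(t-δ)/α_t) * x_t - (α_(t-δ) - α_t) * e_θ / denom
    sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** 0.5
    denom = alpha_prod_t * beta_prod_t_prev**0.5 + (alpha_prod_t * beta_prod_t * alpha_prod_t_prev) ** 0.5
    return sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / denom
```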
@@ -489,5 +493,5 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps
 
@@ -73,7 +73,7 @@ if is_torchao_available():
 
 @require_torch
 @require_torch_accelerator
-@require_torchao_version_greater_or_equal("0.7.0")
+@require_torchao_version_greater_or_equal("0.14.0")
 class TorchAoConfigTest(unittest.TestCase):
     def test_to_dict(self):
         """
@@ -131,7 +131,7 @@ class TorchAoConfigTest(unittest.TestCase):
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_accelerator
-@require_torchao_version_greater_or_equal("0.7.0")
+@require_torchao_version_greater_or_equal("0.14.0")
 class TorchAoTest(unittest.TestCase):
     def tearDown(self):
         gc.collect()
@@ -540,7 +540,7 @@ class TorchAoTest(unittest.TestCase):
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_accelerator
-@require_torchao_version_greater_or_equal("0.7.0")
+@require_torchao_version_greater_or_equal("0.14.0")
 class TorchAoSerializationTest(unittest.TestCase):
     model_name = "hf-internal-testing/tiny-flux-pipe"
 
@@ -651,23 +651,22 @@ class TorchAoSerializationTest(unittest.TestCase):
         self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
 
 
-@require_torchao_version_greater_or_equal("0.7.0")
+@require_torchao_version_greater_or_equal("0.14.0")
 class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
     @property
     def quantization_config(self):
         return PipelineQuantizationConfig(
-            quant_mapping={
-                "transformer": TorchAoConfig(quant_type="int8_weight_only"),
-            },
+            quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig())},
         )
 
     @unittest.skip(
         "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
         "when compiling."
     )
     def test_torch_compile_with_cpu_offload(self):
         pipe = self._init_pipeline(self.quantization_config, torch.bfloat16)
         pipe.enable_model_cpu_offload()
         # No compilation because it fails with:
         # RuntimeError: _apply(): Couldn't swap Linear.weight
         super().test_torch_compile_with_cpu_offload()
 
         # small resolutions to ensure speedy execution.
         pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
 
     @parameterized.expand([False, True])
     @unittest.skip(
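For readers updating their own code to the new config style, a hedged sketch (assumes torchao >= 0.14 exposes `Int8WeightOnlyConfig`, and reuses the tiny test checkpoint named above):

```python
import torch
from diffusers import DiffusionPipeline, TorchAoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from torchao.quantization import Int8WeightOnlyConfig

# torchao config objects replace the older string-based quant_type
quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig())},
)
pipe = DiffusionPipeline.from_pretrained(
    "hf-internal-testing/tiny-flux-pipe",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```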
@@ -698,7 +697,7 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
 # Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
 @require_torch
 @require_torch_accelerator
-@require_torchao_version_greater_or_equal("0.7.0")
+@require_torchao_version_greater_or_equal("0.14.0")
 @slow
 @nightly
 class SlowTorchAoTests(unittest.TestCase):
@@ -857,7 +856,7 @@ class SlowTorchAoTests(unittest.TestCase):
 
 @require_torch
 @require_torch_accelerator
-@require_torchao_version_greater_or_equal("0.7.0")
+@require_torchao_version_greater_or_equal("0.14.0")
 @slow
 @nightly
 class SlowTorchAoPreserializedModelTests(unittest.TestCase):
 