mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-06 19:05:11 +08:00
Compare commits
15 Commits
sayakpaul-
...
ltx2-impro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1531bd0296 | ||
|
|
6462864b83 | ||
|
|
8b24acdc57 | ||
|
|
ca79f8ccc4 | ||
|
|
99e2cfff27 | ||
|
|
a3dcd9882f | ||
|
|
9fe0a9cac4 | ||
|
|
03af690b60 | ||
|
|
90818e82b3 | ||
|
|
430c557b6a | ||
|
|
7354055077 | ||
|
|
cd60b3d151 | ||
|
|
857735f15d | ||
|
|
2e18d2c51a | ||
|
|
d5d2910654 |
@@ -106,8 +106,6 @@ video, audio = pipe(
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)
|
||||
video = (video * 255).round().astype("uint8")
|
||||
video = torch.from_numpy(video)
|
||||
|
||||
encode_video(
|
||||
video[0],
|
||||
@@ -185,8 +183,6 @@ video, audio = pipe(
|
||||
output_type="np",
|
||||
return_dict=False,
|
||||
)
|
||||
video = (video * 255).round().astype("uint8")
|
||||
video = torch.from_numpy(video)
|
||||
|
||||
encode_video(
|
||||
video[0],
|
||||
|
||||
@@ -53,6 +53,41 @@ image = pipe(
|
||||
image.save("zimage_img2img.png")
|
||||
```
|
||||
|
||||
## Inpainting
|
||||
|
||||
Use [`ZImageInpaintPipeline`] to inpaint specific regions of an image based on a text prompt and mask.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from diffusers import ZImageInpaintPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = ZImageInpaintPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
init_image = load_image(url).resize((1024, 1024))
|
||||
|
||||
# Create a mask (white = inpaint, black = preserve)
|
||||
mask = np.zeros((1024, 1024), dtype=np.uint8)
|
||||
mask[256:768, 256:768] = 255 # Inpaint center region
|
||||
mask_image = Image.fromarray(mask)
|
||||
|
||||
prompt = "A beautiful lake with mountains in the background"
|
||||
image = pipe(
|
||||
prompt,
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
strength=1.0,
|
||||
num_inference_steps=9,
|
||||
guidance_scale=0.0,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).images[0]
|
||||
image.save("zimage_inpaint.png")
|
||||
```
|
||||
|
||||
## ZImagePipeline
|
||||
|
||||
[[autodoc]] ZImagePipeline
|
||||
@@ -64,3 +99,9 @@ image.save("zimage_img2img.png")
|
||||
[[autodoc]] ZImageImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ZImageInpaintPipeline
|
||||
|
||||
[[autodoc]] ZImageInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -111,3 +111,57 @@ config = TaylorSeerCacheConfig(
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## MagCache
|
||||
|
||||
[MagCache](https://github.com/Zehong-Ma/MagCache) accelerates inference by skipping transformer blocks based on the magnitude of the residual update. It observes that the magnitude of updates (Output - Input) decays predictably over the diffusion process. By accumulating an "error budget" based on pre-computed magnitude ratios, it dynamically decides when to skip computation and reuse the previous residual.
|
||||
|
||||
MagCache relies on **Magnitude Ratios** (`mag_ratios`), which describe this decay curve. These ratios are specific to the model checkpoint and scheduler.
|
||||
|
||||
### Usage
|
||||
|
||||
To use MagCache, you typically follow a two-step process: **Calibration** and **Inference**.
|
||||
|
||||
1. **Calibration**: Run inference once with `calibrate=True`. The hook will measure the residual magnitudes and print the calculated ratios to the console.
|
||||
2. **Inference**: Pass these ratios to `MagCacheConfig` to enable acceleration.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, MagCacheConfig
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-schnell",
|
||||
torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
# 1. Calibration Step
|
||||
# Run full inference to measure model behavior.
|
||||
calib_config = MagCacheConfig(calibrate=True, num_inference_steps=4)
|
||||
pipe.transformer.enable_cache(calib_config)
|
||||
|
||||
# Run a prompt to trigger calibration
|
||||
pipe("A cat playing chess", num_inference_steps=4)
|
||||
# Logs will print something like: "MagCache Calibration Results: [1.0, 1.37, 0.97, 0.87]"
|
||||
|
||||
# 2. Inference Step
|
||||
# Apply the specific ratios obtained from calibration for optimized speed.
|
||||
# Note: For Flux models, you can also import defaults:
|
||||
# from diffusers.hooks.mag_cache import FLUX_MAG_RATIOS
|
||||
mag_config = MagCacheConfig(
|
||||
mag_ratios=[1.0, 1.37, 0.97, 0.87],
|
||||
num_inference_steps=4
|
||||
)
|
||||
|
||||
pipe.transformer.enable_cache(mag_config)
|
||||
|
||||
image = pipe("A cat playing chess", num_inference_steps=4).images[0]
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> `mag_ratios` represent the model's intrinsic magnitude decay curve. Ratios calibrated for a high number of steps (e.g., 50) can be reused for lower step counts (e.g., 20). The implementation uses interpolation to map the curve to the current number of inference steps.
|
||||
|
||||
> [!TIP]
|
||||
> For pipelines that run Classifier-Free Guidance sequentially (like Kandinsky 5.0), the calibration log might print two arrays: one for the Conditional pass and one for the Unconditional pass. In most cases, you should use the first array (Conditional).
|
||||
|
||||
> [!TIP]
|
||||
> For pipelines that run Classifier-Free Guidance in a **batched** manner (like SDXL or Flux), the `hidden_states` processed by the model contain both conditional and unconditional branches concatenated together. The calibration process automatically accounts for this, producing a single array of ratios that represents the joint behavior. You can use this resulting array directly without modification.
|
||||
|
||||
@@ -66,7 +66,7 @@ from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConf
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128)))}
|
||||
quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128))}
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
|
||||
@@ -168,12 +168,14 @@ else:
|
||||
"FirstBlockCacheConfig",
|
||||
"HookRegistry",
|
||||
"LayerSkipConfig",
|
||||
"MagCacheConfig",
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"SmoothedEnergyGuidanceConfig",
|
||||
"TaylorSeerCacheConfig",
|
||||
"apply_faster_cache",
|
||||
"apply_first_block_cache",
|
||||
"apply_layer_skip",
|
||||
"apply_mag_cache",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
"apply_taylorseer_cache",
|
||||
]
|
||||
@@ -694,6 +696,7 @@ else:
|
||||
"ZImageControlNetInpaintPipeline",
|
||||
"ZImageControlNetPipeline",
|
||||
"ZImageImg2ImgPipeline",
|
||||
"ZImageInpaintPipeline",
|
||||
"ZImageOmniPipeline",
|
||||
"ZImagePipeline",
|
||||
]
|
||||
@@ -932,12 +935,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FirstBlockCacheConfig,
|
||||
HookRegistry,
|
||||
LayerSkipConfig,
|
||||
MagCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
SmoothedEnergyGuidanceConfig,
|
||||
TaylorSeerCacheConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_layer_skip,
|
||||
apply_mag_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
apply_taylorseer_cache,
|
||||
)
|
||||
@@ -1424,6 +1429,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ZImageControlNetInpaintPipeline,
|
||||
ZImageControlNetPipeline,
|
||||
ZImageImg2ImgPipeline,
|
||||
ZImageInpaintPipeline,
|
||||
ZImageOmniPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ if is_torch_available():
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layer_skip import LayerSkipConfig, apply_layer_skip
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
from .mag_cache import MagCacheConfig, apply_mag_cache
|
||||
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
|
||||
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
|
||||
|
||||
@@ -23,7 +23,13 @@ from ..models.attention_processor import Attention, MochiAttention
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
|
||||
|
||||
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
|
||||
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = (
|
||||
"blocks",
|
||||
"transformer_blocks",
|
||||
"single_transformer_blocks",
|
||||
"layers",
|
||||
"visual_transformer_blocks",
|
||||
)
|
||||
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ class AttentionProcessorMetadata:
|
||||
class TransformerBlockMetadata:
|
||||
return_hidden_states_index: int = None
|
||||
return_encoder_hidden_states_index: int = None
|
||||
hidden_states_argument_name: str = "hidden_states"
|
||||
|
||||
_cls: Type = None
|
||||
_cached_parameter_indices: Dict[str, int] = None
|
||||
@@ -169,7 +170,7 @@ def _register_attention_processors_metadata():
|
||||
|
||||
|
||||
def _register_transformer_blocks_metadata():
|
||||
from ..models.attention import BasicTransformerBlock
|
||||
from ..models.attention import BasicTransformerBlock, JointTransformerBlock
|
||||
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
|
||||
from ..models.transformers.transformer_bria import BriaTransformerBlock
|
||||
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
|
||||
@@ -184,6 +185,7 @@ def _register_transformer_blocks_metadata():
|
||||
HunyuanImageSingleTransformerBlock,
|
||||
HunyuanImageTransformerBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock
|
||||
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
|
||||
from ..models.transformers.transformer_mochi import MochiTransformerBlock
|
||||
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
|
||||
@@ -331,6 +333,24 @@ def _register_transformer_blocks_metadata():
|
||||
),
|
||||
)
|
||||
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=JointTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=1,
|
||||
return_encoder_hidden_states_index=0,
|
||||
),
|
||||
)
|
||||
|
||||
# Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=Kandinsky5TransformerDecoderBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
hidden_states_argument_name="visual_embed",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
|
||||
|
||||
468
src/diffusers/hooks/mag_cache.py
Normal file
468
src/diffusers/hooks/mag_cache.py
Normal file
@@ -0,0 +1,468 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
from ._helpers import TransformerBlockRegistry
|
||||
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_MAG_CACHE_LEADER_BLOCK_HOOK = "mag_cache_leader_block_hook"
|
||||
_MAG_CACHE_BLOCK_HOOK = "mag_cache_block_hook"
|
||||
|
||||
# Default Mag Ratios for Flux models (Dev/Schnell) are provided for convenience.
|
||||
# Users must explicitly pass these to the config if using Flux.
|
||||
# Reference: https://github.com/Zehong-Ma/MagCache
|
||||
FLUX_MAG_RATIOS = torch.tensor(
|
||||
[1.0]
|
||||
+ [
|
||||
1.21094,
|
||||
1.11719,
|
||||
1.07812,
|
||||
1.0625,
|
||||
1.03906,
|
||||
1.03125,
|
||||
1.03906,
|
||||
1.02344,
|
||||
1.03125,
|
||||
1.02344,
|
||||
0.98047,
|
||||
1.01562,
|
||||
1.00781,
|
||||
1.0,
|
||||
1.00781,
|
||||
1.0,
|
||||
1.00781,
|
||||
1.0,
|
||||
1.0,
|
||||
0.99609,
|
||||
0.99609,
|
||||
0.98047,
|
||||
0.98828,
|
||||
0.96484,
|
||||
0.95703,
|
||||
0.93359,
|
||||
0.89062,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor:
|
||||
"""
|
||||
Interpolate the source array to the target length using nearest neighbor interpolation.
|
||||
"""
|
||||
src_length = len(src_array)
|
||||
if target_length == 1:
|
||||
return src_array[-1:]
|
||||
|
||||
scale = (src_length - 1) / (target_length - 1)
|
||||
grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32)
|
||||
mapped_indices = torch.round(grid * scale).long()
|
||||
return src_array[mapped_indices]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MagCacheConfig:
|
||||
r"""
|
||||
Configuration for [MagCache](https://github.com/Zehong-Ma/MagCache).
|
||||
|
||||
Args:
|
||||
threshold (`float`, defaults to `0.06`):
|
||||
The threshold for the accumulated error. If the accumulated error is below this threshold, the block
|
||||
computation is skipped. A higher threshold allows for more aggressive skipping (faster) but may degrade
|
||||
quality.
|
||||
max_skip_steps (`int`, defaults to `3`):
|
||||
The maximum number of consecutive steps that can be skipped (K in the paper).
|
||||
retention_ratio (`float`, defaults to `0.2`):
|
||||
The fraction of initial steps during which skipping is disabled to ensure stability. For example, if
|
||||
`num_inference_steps` is 28 and `retention_ratio` is 0.2, the first 6 steps will never be skipped.
|
||||
num_inference_steps (`int`, defaults to `28`):
|
||||
The number of inference steps used in the pipeline. This is required to interpolate `mag_ratios` correctly.
|
||||
mag_ratios (`torch.Tensor`, *optional*):
|
||||
The pre-computed magnitude ratios for the model. These are checkpoint-dependent. If not provided, you must
|
||||
set `calibrate=True` to calculate them for your specific model. For Flux models, you can use
|
||||
`diffusers.hooks.mag_cache.FLUX_MAG_RATIOS`.
|
||||
calibrate (`bool`, defaults to `False`):
|
||||
If True, enables calibration mode. In this mode, no blocks are skipped. Instead, the hook calculates the
|
||||
magnitude ratios for the current run and logs them at the end. Use this to obtain `mag_ratios` for new
|
||||
models or schedulers.
|
||||
"""
|
||||
|
||||
threshold: float = 0.06
|
||||
max_skip_steps: int = 3
|
||||
retention_ratio: float = 0.2
|
||||
num_inference_steps: int = 28
|
||||
mag_ratios: Optional[Union[torch.Tensor, List[float]]] = None
|
||||
calibrate: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# User MUST provide ratios OR enable calibration.
|
||||
if self.mag_ratios is None and not self.calibrate:
|
||||
raise ValueError(
|
||||
" `mag_ratios` must be provided for MagCache inference because these ratios are model-dependent.\n"
|
||||
"To get them for your model:\n"
|
||||
"1. Initialize `MagCacheConfig(calibrate=True, ...)`\n"
|
||||
"2. Run inference on your model once.\n"
|
||||
"3. Copy the printed ratios array and pass it to `mag_ratios` in the config.\n"
|
||||
"For Flux models, you can import `FLUX_MAG_RATIOS` from `diffusers.hooks.mag_cache`."
|
||||
)
|
||||
|
||||
if not self.calibrate and self.mag_ratios is not None:
|
||||
if not torch.is_tensor(self.mag_ratios):
|
||||
self.mag_ratios = torch.tensor(self.mag_ratios)
|
||||
|
||||
if len(self.mag_ratios) != self.num_inference_steps:
|
||||
logger.debug(
|
||||
f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}"
|
||||
)
|
||||
self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps)
|
||||
|
||||
|
||||
class MagCacheState(BaseState):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# Cache for the residual (output - input) from the *previous* timestep
|
||||
self.previous_residual: torch.Tensor = None
|
||||
|
||||
# State inputs/outputs for the current forward pass
|
||||
self.head_block_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
||||
self.should_compute: bool = True
|
||||
|
||||
# MagCache accumulators
|
||||
self.accumulated_ratio: float = 1.0
|
||||
self.accumulated_err: float = 0.0
|
||||
self.accumulated_steps: int = 0
|
||||
|
||||
# Current step counter (timestep index)
|
||||
self.step_index: int = 0
|
||||
|
||||
# Calibration storage
|
||||
self.calibration_ratios: List[float] = []
|
||||
|
||||
def reset(self):
|
||||
self.previous_residual = None
|
||||
self.should_compute = True
|
||||
self.accumulated_ratio = 1.0
|
||||
self.accumulated_err = 0.0
|
||||
self.accumulated_steps = 0
|
||||
self.step_index = 0
|
||||
self.calibration_ratios = []
|
||||
|
||||
|
||||
class MagCacheHeadHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(self, state_manager: StateManager, config: MagCacheConfig):
|
||||
self.state_manager = state_manager
|
||||
self.config = config
|
||||
self._metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
unwrapped_module = unwrap_module(module)
|
||||
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
|
||||
return module
|
||||
|
||||
@torch.compiler.disable
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if self.state_manager._current_context is None:
|
||||
self.state_manager.set_context("inference")
|
||||
|
||||
arg_name = self._metadata.hidden_states_argument_name
|
||||
hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)
|
||||
|
||||
state: MagCacheState = self.state_manager.get_state()
|
||||
state.head_block_input = hidden_states
|
||||
|
||||
should_compute = True
|
||||
|
||||
if self.config.calibrate:
|
||||
# Never skip during calibration
|
||||
should_compute = True
|
||||
else:
|
||||
# MagCache Logic
|
||||
current_step = state.step_index
|
||||
if current_step >= len(self.config.mag_ratios):
|
||||
current_scale = 1.0
|
||||
else:
|
||||
current_scale = self.config.mag_ratios[current_step]
|
||||
|
||||
retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5)
|
||||
|
||||
if current_step >= retention_step:
|
||||
state.accumulated_ratio *= current_scale
|
||||
state.accumulated_steps += 1
|
||||
state.accumulated_err += abs(1.0 - state.accumulated_ratio)
|
||||
|
||||
if (
|
||||
state.previous_residual is not None
|
||||
and state.accumulated_err <= self.config.threshold
|
||||
and state.accumulated_steps <= self.config.max_skip_steps
|
||||
):
|
||||
should_compute = False
|
||||
else:
|
||||
state.accumulated_ratio = 1.0
|
||||
state.accumulated_steps = 0
|
||||
state.accumulated_err = 0.0
|
||||
|
||||
state.should_compute = should_compute
|
||||
|
||||
if not should_compute:
|
||||
logger.debug(f"MagCache: Skipping step {state.step_index}")
|
||||
# Apply MagCache: Output = Input + Previous Residual
|
||||
|
||||
output = hidden_states
|
||||
res = state.previous_residual
|
||||
|
||||
if res.device != output.device:
|
||||
res = res.to(output.device)
|
||||
|
||||
# Attempt to apply residual handling shape mismatches (e.g., text+image vs image only)
|
||||
if res.shape == output.shape:
|
||||
output = output + res
|
||||
elif (
|
||||
output.ndim == 3
|
||||
and res.ndim == 3
|
||||
and output.shape[0] == res.shape[0]
|
||||
and output.shape[2] == res.shape[2]
|
||||
):
|
||||
# Assuming concatenation where image part is at the end (standard in Flux/SD3)
|
||||
diff = output.shape[1] - res.shape[1]
|
||||
if diff > 0:
|
||||
output = output.clone()
|
||||
output[:, diff:, :] = output[:, diff:, :] + res
|
||||
else:
|
||||
logger.warning(
|
||||
f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. "
|
||||
"Cannot apply residual safely. Returning input without residual."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. "
|
||||
"Cannot apply residual safely. Returning input without residual."
|
||||
)
|
||||
|
||||
if self._metadata.return_encoder_hidden_states_index is not None:
|
||||
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
|
||||
"encoder_hidden_states", args, kwargs
|
||||
)
|
||||
max_idx = max(
|
||||
self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
|
||||
)
|
||||
ret_list = [None] * (max_idx + 1)
|
||||
ret_list[self._metadata.return_hidden_states_index] = output
|
||||
ret_list[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
|
||||
return tuple(ret_list)
|
||||
else:
|
||||
return output
|
||||
|
||||
else:
|
||||
# Compute original forward
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
return output
|
||||
|
||||
def reset_state(self, module):
|
||||
self.state_manager.reset()
|
||||
return module
|
||||
|
||||
|
||||
class MagCacheBlockHook(ModelHook):
|
||||
def __init__(self, state_manager: StateManager, is_tail: bool = False, config: MagCacheConfig = None):
|
||||
super().__init__()
|
||||
self.state_manager = state_manager
|
||||
self.is_tail = is_tail
|
||||
self.config = config
|
||||
self._metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
unwrapped_module = unwrap_module(module)
|
||||
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
|
||||
return module
|
||||
|
||||
@torch.compiler.disable
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if self.state_manager._current_context is None:
|
||||
self.state_manager.set_context("inference")
|
||||
state: MagCacheState = self.state_manager.get_state()
|
||||
|
||||
if not state.should_compute:
|
||||
arg_name = self._metadata.hidden_states_argument_name
|
||||
hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)
|
||||
|
||||
if self.is_tail:
|
||||
# Still need to advance step index even if we skip
|
||||
self._advance_step(state)
|
||||
|
||||
if self._metadata.return_encoder_hidden_states_index is not None:
|
||||
encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
|
||||
"encoder_hidden_states", args, kwargs
|
||||
)
|
||||
max_idx = max(
|
||||
self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
|
||||
)
|
||||
ret_list = [None] * (max_idx + 1)
|
||||
ret_list[self._metadata.return_hidden_states_index] = hidden_states
|
||||
ret_list[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
|
||||
return tuple(ret_list)
|
||||
|
||||
return hidden_states
|
||||
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
|
||||
if self.is_tail:
|
||||
# Calculate residual for next steps
|
||||
if isinstance(output, tuple):
|
||||
out_hidden = output[self._metadata.return_hidden_states_index]
|
||||
else:
|
||||
out_hidden = output
|
||||
|
||||
in_hidden = state.head_block_input
|
||||
|
||||
if in_hidden is None:
|
||||
return output
|
||||
|
||||
# Determine residual
|
||||
if out_hidden.shape == in_hidden.shape:
|
||||
residual = out_hidden - in_hidden
|
||||
elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]:
|
||||
diff = in_hidden.shape[1] - out_hidden.shape[1]
|
||||
if diff == 0:
|
||||
residual = out_hidden - in_hidden
|
||||
else:
|
||||
residual = out_hidden - in_hidden # Fallback to matching tail
|
||||
else:
|
||||
# Fallback for completely mismatched shapes
|
||||
residual = out_hidden
|
||||
|
||||
if self.config.calibrate:
|
||||
self._perform_calibration_step(state, residual)
|
||||
|
||||
state.previous_residual = residual
|
||||
self._advance_step(state)
|
||||
|
||||
return output
|
||||
|
||||
def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor):
|
||||
if state.previous_residual is None:
|
||||
# First step has no previous residual to compare against.
|
||||
# log 1.0 as a neutral starting point.
|
||||
ratio = 1.0
|
||||
else:
|
||||
# MagCache Calibration Formula: mean(norm(curr) / norm(prev))
|
||||
# norm(dim=-1) gives magnitude of each token vector
|
||||
curr_norm = torch.linalg.norm(current_residual.float(), dim=-1)
|
||||
prev_norm = torch.linalg.norm(state.previous_residual.float(), dim=-1)
|
||||
|
||||
# Avoid division by zero
|
||||
ratio = (curr_norm / (prev_norm + 1e-8)).mean().item()
|
||||
|
||||
state.calibration_ratios.append(ratio)
|
||||
|
||||
def _advance_step(self, state: MagCacheState):
|
||||
state.step_index += 1
|
||||
if state.step_index >= self.config.num_inference_steps:
|
||||
# End of inference loop
|
||||
if self.config.calibrate:
|
||||
print("\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):")
|
||||
print(f"{state.calibration_ratios}\n")
|
||||
logger.info(f"MagCache Calibration Results: {state.calibration_ratios}")
|
||||
|
||||
# Reset state
|
||||
state.step_index = 0
|
||||
state.accumulated_ratio = 1.0
|
||||
state.accumulated_steps = 0
|
||||
state.accumulated_err = 0.0
|
||||
state.previous_residual = None
|
||||
state.calibration_ratios = []
|
||||
|
||||
|
||||
def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
|
||||
"""
|
||||
Applies MagCache to a given module (typically a Transformer).
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to apply MagCache to.
|
||||
config (`MagCacheConfig`):
|
||||
The configuration for MagCache.
|
||||
"""
|
||||
# Initialize registry on the root module so the Pipeline can set context.
|
||||
HookRegistry.check_if_exists_or_initialize(module)
|
||||
|
||||
state_manager = StateManager(MagCacheState, (), {})
|
||||
remaining_blocks = []
|
||||
|
||||
for name, submodule in module.named_children():
|
||||
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
|
||||
continue
|
||||
for index, block in enumerate(submodule):
|
||||
remaining_blocks.append((f"{name}.{index}", block))
|
||||
|
||||
if not remaining_blocks:
|
||||
logger.warning("MagCache: No transformer blocks found to apply hooks.")
|
||||
return
|
||||
|
||||
# Handle single-block models
|
||||
if len(remaining_blocks) == 1:
|
||||
name, block = remaining_blocks[0]
|
||||
logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{name}'")
|
||||
_apply_mag_cache_block_hook(block, state_manager, config, is_tail=True)
|
||||
_apply_mag_cache_head_hook(block, state_manager, config)
|
||||
return
|
||||
|
||||
head_block_name, head_block = remaining_blocks.pop(0)
|
||||
tail_block_name, tail_block = remaining_blocks.pop(-1)
|
||||
|
||||
logger.info(f"MagCache: Applying Head Hook to {head_block_name}")
|
||||
_apply_mag_cache_head_hook(head_block, state_manager, config)
|
||||
|
||||
for name, block in remaining_blocks:
|
||||
_apply_mag_cache_block_hook(block, state_manager, config)
|
||||
|
||||
logger.info(f"MagCache: Applying Tail Hook to {tail_block_name}")
|
||||
_apply_mag_cache_block_hook(tail_block, state_manager, config, is_tail=True)
|
||||
|
||||
|
||||
def _apply_mag_cache_head_hook(block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
|
||||
# Automatically remove existing hook to allow re-application (e.g. switching modes)
|
||||
if registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) is not None:
|
||||
registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK)
|
||||
|
||||
hook = MagCacheHeadHook(state_manager, config)
|
||||
registry.register_hook(hook, _MAG_CACHE_LEADER_BLOCK_HOOK)
|
||||
|
||||
|
||||
def _apply_mag_cache_block_hook(
|
||||
block: torch.nn.Module,
|
||||
state_manager: StateManager,
|
||||
config: MagCacheConfig,
|
||||
is_tail: bool = False,
|
||||
) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
|
||||
# Automatically remove existing hook to allow re-application
|
||||
if registry.get_hook(_MAG_CACHE_BLOCK_HOOK) is not None:
|
||||
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK)
|
||||
|
||||
hook = MagCacheBlockHook(state_manager, is_tail, config)
|
||||
registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK)
|
||||
@@ -68,10 +68,12 @@ class CacheMixin:
|
||||
from ..hooks import (
|
||||
FasterCacheConfig,
|
||||
FirstBlockCacheConfig,
|
||||
MagCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
TaylorSeerCacheConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_mag_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
apply_taylorseer_cache,
|
||||
)
|
||||
@@ -85,6 +87,8 @@ class CacheMixin:
|
||||
apply_faster_cache(self, config)
|
||||
elif isinstance(config, FirstBlockCacheConfig):
|
||||
apply_first_block_cache(self, config)
|
||||
elif isinstance(config, MagCacheConfig):
|
||||
apply_mag_cache(self, config)
|
||||
elif isinstance(config, PyramidAttentionBroadcastConfig):
|
||||
apply_pyramid_attention_broadcast(self, config)
|
||||
elif isinstance(config, TaylorSeerCacheConfig):
|
||||
@@ -99,11 +103,13 @@ class CacheMixin:
|
||||
FasterCacheConfig,
|
||||
FirstBlockCacheConfig,
|
||||
HookRegistry,
|
||||
MagCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
TaylorSeerCacheConfig,
|
||||
)
|
||||
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
||||
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
|
||||
from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK
|
||||
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
|
||||
from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
|
||||
|
||||
@@ -118,6 +124,9 @@ class CacheMixin:
|
||||
elif isinstance(self._cache_config, FirstBlockCacheConfig):
|
||||
registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
|
||||
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, MagCacheConfig):
|
||||
registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True)
|
||||
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
|
||||
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, TaylorSeerCacheConfig):
|
||||
|
||||
@@ -125,9 +125,9 @@ class BriaFiboAttnProcessor:
|
||||
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
||||
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
||||
)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[0](hidden_states.contiguous())
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states.contiguous())
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
|
||||
@@ -130,9 +130,9 @@ class FluxAttnProcessor:
|
||||
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
||||
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
||||
)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[0](hidden_states.contiguous())
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states.contiguous())
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
|
||||
@@ -561,11 +561,11 @@ class QwenDoubleStreamAttnProcessor2_0:
|
||||
img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
|
||||
|
||||
# Apply output projections
|
||||
img_attn_output = attn.to_out[0](img_attn_output)
|
||||
img_attn_output = attn.to_out[0](img_attn_output.contiguous())
|
||||
if len(attn.to_out) > 1:
|
||||
img_attn_output = attn.to_out[1](img_attn_output) # dropout
|
||||
|
||||
txt_attn_output = attn.to_add_out(txt_attn_output)
|
||||
txt_attn_output = attn.to_add_out(txt_attn_output.contiguous())
|
||||
|
||||
return img_attn_output, txt_attn_output
|
||||
|
||||
|
||||
@@ -410,11 +410,12 @@ else:
|
||||
"Kandinsky5I2IPipeline",
|
||||
]
|
||||
_import_structure["z_image"] = [
|
||||
"ZImageImg2ImgPipeline",
|
||||
"ZImagePipeline",
|
||||
"ZImageControlNetPipeline",
|
||||
"ZImageControlNetInpaintPipeline",
|
||||
"ZImageControlNetPipeline",
|
||||
"ZImageImg2ImgPipeline",
|
||||
"ZImageInpaintPipeline",
|
||||
"ZImageOmniPipeline",
|
||||
"ZImagePipeline",
|
||||
]
|
||||
_import_structure["skyreels_v2"] = [
|
||||
"SkyReelsV2DiffusionForcingPipeline",
|
||||
@@ -870,6 +871,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ZImageControlNetInpaintPipeline,
|
||||
ZImageControlNetPipeline,
|
||||
ZImageImg2ImgPipeline,
|
||||
ZImageInpaintPipeline,
|
||||
ZImageOmniPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
|
||||
@@ -127,6 +127,7 @@ from .z_image import (
|
||||
ZImageControlNetInpaintPipeline,
|
||||
ZImageControlNetPipeline,
|
||||
ZImageImg2ImgPipeline,
|
||||
ZImageInpaintPipeline,
|
||||
ZImageOmniPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
@@ -235,6 +236,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
|
||||
("qwenimage", QwenImageInpaintPipeline),
|
||||
("qwenimage-edit", QwenImageEditInpaintPipeline),
|
||||
("z-image", ZImageInpaintPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -13,10 +13,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Iterator
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from itertools import chain
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from ...utils import is_av_available
|
||||
|
||||
@@ -101,11 +106,54 @@ def _write_audio(
|
||||
|
||||
|
||||
def encode_video(
|
||||
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
|
||||
video: Union[List[PIL.Image.Image], np.ndarray, torch.Tensor, Iterator[torch.Tensor]],
|
||||
fps: int,
|
||||
audio: Optional[torch.Tensor],
|
||||
audio_sample_rate: Optional[int],
|
||||
output_path: str,
|
||||
video_chunks_number: int = 1,
|
||||
) -> None:
|
||||
video_np = video.cpu().numpy()
|
||||
"""
|
||||
Encodes a video with audio using the PyAV library. Based on code from the original LTX-2 repo:
|
||||
https://github.com/Lightricks/LTX-2/blob/4f410820b198e05074a1e92de793e3b59e9ab5a0/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L182
|
||||
|
||||
_, height, width, _ = video_np.shape
|
||||
Args:
|
||||
video (`List[PIL.Image.Image]` or `np.ndarray` or `torch.Tensor`):
|
||||
A video tensor of shape [frames, height, width, channels] with integer pixel values in [0, 255]. If the
|
||||
input is a `np.ndarray`, it is expected to be a float array with values in [0, 1] (which is what pipelines
|
||||
usually return with `output_type="np"`).
|
||||
fps (`int`)
|
||||
The frames per second (FPS) of the encoded video.
|
||||
audio (`torch.Tensor`, *optional*):
|
||||
An audio waveform of shape [audio_channels, samples].
|
||||
audio_sample_rate: (`int`, *optional*):
|
||||
The sampling rate of the audio waveform. For LTX 2, this is typically 24000 (24 kHz).
|
||||
output_path (`str`):
|
||||
The path to save the encoded video to.
|
||||
video_chunks_number (`int`, *optional*, defaults to `1`):
|
||||
The number of chunks to split the video into for encoding. Each chunk will be encoded separately. The
|
||||
number of chunks to use often depends on the tiling config for the video VAE.
|
||||
"""
|
||||
if isinstance(video, list) and isinstance(video[0], PIL.Image.Image):
|
||||
# Pipeline output_type="pil"; assumes each image is in "RGB" mode
|
||||
video_frames = [np.array(frame) for frame in video]
|
||||
video = np.stack(video_frames, axis=0)
|
||||
video = torch.from_numpy(video)
|
||||
elif isinstance(video, np.ndarray):
|
||||
# Pipeline output_type="np"
|
||||
is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video))
|
||||
if np.all(is_denormalized):
|
||||
video = (video * 255).round().astype("uint8")
|
||||
video = torch.from_numpy(video)
|
||||
|
||||
if isinstance(video, torch.Tensor):
|
||||
# Split into video_chunks_number along the frame dimension
|
||||
video = torch.tensor_split(video, video_chunks_number, dim=0)
|
||||
video = iter(video)
|
||||
|
||||
first_chunk = next(video)
|
||||
|
||||
_, height, width, _ = first_chunk.shape
|
||||
|
||||
container = av.open(output_path, mode="w")
|
||||
stream = container.add_stream("libx264", rate=int(fps))
|
||||
@@ -119,10 +167,12 @@ def encode_video(
|
||||
|
||||
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
|
||||
|
||||
for frame_array in video_np:
|
||||
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
for video_chunk in tqdm(chain([first_chunk], video), total=video_chunks_number, desc="Encoding video chunks"):
|
||||
video_chunk_cpu = video_chunk.to("cpu").numpy()
|
||||
for frame_array in video_chunk_cpu:
|
||||
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
|
||||
# Flush encoder
|
||||
for packet in stream.encode():
|
||||
|
||||
@@ -69,8 +69,6 @@ EXAMPLE_DOC_STRING = """
|
||||
... output_type="np",
|
||||
... return_dict=False,
|
||||
... )
|
||||
>>> video = (video * 255).round().astype("uint8")
|
||||
>>> video = torch.from_numpy(video)
|
||||
|
||||
>>> encode_video(
|
||||
... video[0],
|
||||
|
||||
@@ -75,8 +75,6 @@ EXAMPLE_DOC_STRING = """
|
||||
... output_type="np",
|
||||
... return_dict=False,
|
||||
... )
|
||||
>>> video = (video * 255).round().astype("uint8")
|
||||
>>> video = torch.from_numpy(video)
|
||||
|
||||
>>> encode_video(
|
||||
... video[0],
|
||||
|
||||
@@ -76,8 +76,6 @@ EXAMPLE_DOC_STRING = """
|
||||
... output_type="np",
|
||||
... return_dict=False,
|
||||
... )[0]
|
||||
>>> video = (video * 255).round().astype("uint8")
|
||||
>>> video = torch.from_numpy(video)
|
||||
|
||||
>>> encode_video(
|
||||
... video[0],
|
||||
|
||||
@@ -26,6 +26,7 @@ else:
|
||||
_import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"]
|
||||
_import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"]
|
||||
_import_structure["pipeline_z_image_inpaint"] = ["ZImageInpaintPipeline"]
|
||||
_import_structure["pipeline_z_image_omni"] = ["ZImageOmniPipeline"]
|
||||
|
||||
|
||||
@@ -42,6 +43,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_z_image_controlnet import ZImageControlNetPipeline
|
||||
from .pipeline_z_image_controlnet_inpaint import ZImageControlNetInpaintPipeline
|
||||
from .pipeline_z_image_img2img import ZImageImg2ImgPipeline
|
||||
from .pipeline_z_image_inpaint import ZImageInpaintPipeline
|
||||
from .pipeline_z_image_omni import ZImageOmniPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
932
src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py
Normal file
932
src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py
Normal file
@@ -0,0 +1,932 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, PreTrainedModel
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import ZImageTransformer2DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from .pipeline_output import ZImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import ZImageInpaintPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = ZImageInpaintPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
>>> init_image = load_image(url).resize((1024, 1024))
|
||||
|
||||
>>> # Create a mask (white = inpaint, black = preserve)
|
||||
>>> import numpy as np
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> mask = np.zeros((1024, 1024), dtype=np.uint8)
|
||||
>>> mask[256:768, 256:768] = 255 # Inpaint center region
|
||||
>>> mask_image = Image.fromarray(mask)
|
||||
|
||||
>>> prompt = "A beautiful lake with mountains in the background"
|
||||
>>> image = pipe(
|
||||
... prompt,
|
||||
... image=init_image,
|
||||
... mask_image=mask_image,
|
||||
... strength=1.0,
|
||||
... num_inference_steps=9,
|
||||
... guidance_scale=0.0,
|
||||
... generator=torch.Generator("cuda").manual_seed(42),
|
||||
... ).images[0]
|
||||
>>> image.save("zimage_inpaint.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class ZImageInpaintPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
|
||||
r"""
|
||||
The ZImage pipeline for inpainting.
|
||||
|
||||
Args:
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`PreTrainedModel`]):
|
||||
A text encoder model to encode text prompts.
|
||||
tokenizer ([`AutoTokenizer`]):
|
||||
A tokenizer to tokenize text prompts.
|
||||
transformer ([`ZImageTransformer2DModel`]):
|
||||
A ZImage transformer model to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "mask", "masked_image_latents"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: PreTrainedModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
transformer: ZImageTransformer2DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
transformer=transformer,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor * 2,
|
||||
do_normalize=False,
|
||||
do_binarize=True,
|
||||
do_convert_grayscale=True,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ["" for _ in prompt]
|
||||
else:
|
||||
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
assert len(prompt) == len(negative_prompt)
|
||||
negative_prompt_embeds = self._encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
device=device,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
else:
|
||||
negative_prompt_embeds = []
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
) -> List[torch.FloatTensor]:
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt_embeds is not None:
|
||||
return prompt_embeds
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
for i, prompt_item in enumerate(prompt):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt_item},
|
||||
]
|
||||
prompt_item = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
)
|
||||
prompt[i] = prompt_item
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_masks,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
|
||||
embeddings_list = []
|
||||
|
||||
for i in range(len(prompt_embeds)):
|
||||
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
|
||||
|
||||
return embeddings_list
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
||||
|
||||
t_start = int(max(num_inference_steps - init_timestep, 0))
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_mask_latents(
|
||||
self,
|
||||
mask,
|
||||
masked_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
):
|
||||
"""Prepare mask and masked image latents for inpainting.
|
||||
|
||||
Args:
|
||||
mask: Binary mask tensor where 1 = inpaint region, 0 = preserve region.
|
||||
masked_image: Original image with masked regions zeroed out.
|
||||
batch_size: Number of images to generate.
|
||||
height: Output image height.
|
||||
width: Output image width.
|
||||
dtype: Data type for the tensors.
|
||||
device: Device to place tensors on.
|
||||
generator: Random generator for reproducibility.
|
||||
|
||||
Returns:
|
||||
Tuple of (mask, masked_image_latents) prepared for the denoising loop.
|
||||
"""
|
||||
# Calculate latent dimensions
|
||||
latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
# Resize mask to latent dimensions
|
||||
mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width), mode="nearest")
|
||||
mask = mask.to(device=device, dtype=dtype)
|
||||
|
||||
# Encode masked image to latents
|
||||
masked_image = masked_image.to(device=device, dtype=dtype)
|
||||
if isinstance(generator, list):
|
||||
masked_image_latents = [
|
||||
retrieve_latents(self.vae.encode(masked_image[i : i + 1]), generator=generator[i])
|
||||
for i in range(masked_image.shape[0])
|
||||
]
|
||||
masked_image_latents = torch.cat(masked_image_latents, dim=0)
|
||||
else:
|
||||
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
|
||||
|
||||
# Apply VAE scaling
|
||||
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# Expand for batch size
|
||||
if mask.shape[0] < batch_size:
|
||||
if not batch_size % mask.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
||||
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
||||
" of masks that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
||||
if masked_image_latents.shape[0] < batch_size:
|
||||
if not batch_size % masked_image_latents.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
|
||||
|
||||
return mask, masked_image_latents
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image,
|
||||
timestep,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
"""Prepare latents for inpainting, returning noise and image_latents for blending.
|
||||
|
||||
Returns:
|
||||
Tuple of (latents, noise, image_latents) where:
|
||||
- latents: Noised image latents for denoising
|
||||
- noise: The noise tensor used for blending
|
||||
- image_latents: Clean image latents for blending
|
||||
"""
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
# Generate noise for blending even if latents are provided
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
# Encode image for blending
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
||||
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
image_latents = torch.cat([image_latents] * (batch_size // image_latents.shape[0]), dim=0)
|
||||
return latents.to(device=device, dtype=dtype), noise, image_latents
|
||||
|
||||
# Encode the input image
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if image.shape[1] != num_channels_latents:
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
||||
|
||||
# Apply scaling (inverse of decoding: decode does latents/scaling_factor + shift_factor)
|
||||
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
else:
|
||||
image_latents = image
|
||||
|
||||
# Handle batch size expansion
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
|
||||
# Generate noise for both initial noising and later blending
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
# Add noise using flow matching scale_noise
|
||||
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
|
||||
|
||||
return latents, noise, image_latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
image,
|
||||
mask_image,
|
||||
strength,
|
||||
height,
|
||||
width,
|
||||
output_type,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if image is None:
|
||||
raise ValueError("`image` input cannot be undefined for inpainting.")
|
||||
|
||||
if mask_image is None:
|
||||
raise ValueError("`mask_image` input cannot be undefined for inpainting.")
|
||||
|
||||
if output_type not in ["latent", "pil", "np", "pt"]:
|
||||
raise ValueError(f"`output_type` must be one of 'latent', 'pil', 'np', or 'pt', but got {output_type}")
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: PipelineImageInput = None,
|
||||
mask_image: PipelineImageInput = None,
|
||||
masked_image_latents: Optional[torch.FloatTensor] = None,
|
||||
strength: float = 1.0,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
cfg_normalization: bool = False,
|
||||
cfg_truncation: float = 1.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for inpainting.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
|
||||
list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
|
||||
a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
|
||||
mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing a mask image for inpainting. White pixels (value 1) in the
|
||||
mask will be inpainted, black pixels (value 0) will be preserved from the original image.
|
||||
masked_image_latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-encoded masked image latents. If provided, the masked image encoding step will be skipped.
|
||||
strength (`float`, *optional*, defaults to 1.0):
|
||||
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
||||
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
||||
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
||||
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
||||
essentially ignores `image` in the masked region.
|
||||
height (`int`, *optional*, defaults to 1024):
|
||||
The height in pixels of the generated image. If not provided, uses the input image height.
|
||||
width (`int`, *optional*, defaults to 1024):
|
||||
The width in pixels of the generated image. If not provided, uses the input image width.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
cfg_normalization (`bool`, *optional*, defaults to False):
|
||||
Whether to apply configuration normalization.
|
||||
cfg_truncation (`float`, *optional*, defaults to 1.0):
|
||||
The truncation value for configuration.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
|
||||
tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
mask_image=mask_image,
|
||||
strength=strength,
|
||||
height=height,
|
||||
width=width,
|
||||
output_type=output_type,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
# 2. Preprocess image and mask
|
||||
init_image = self.image_processor.preprocess(image)
|
||||
init_image = init_image.to(dtype=torch.float32)
|
||||
|
||||
# Get dimensions from the preprocessed image if not specified
|
||||
if height is None:
|
||||
height = init_image.shape[-2]
|
||||
if width is None:
|
||||
width = init_image.shape[-1]
|
||||
|
||||
vae_scale = self.vae_scale_factor * 2
|
||||
if height % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Height must be divisible by {vae_scale} (got {height}). "
|
||||
f"Please adjust the height to a multiple of {vae_scale}."
|
||||
)
|
||||
if width % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Width must be divisible by {vae_scale} (got {width}). "
|
||||
f"Please adjust the width to a multiple of {vae_scale}."
|
||||
)
|
||||
|
||||
# Preprocess mask
|
||||
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
self._cfg_normalization = cfg_normalization
|
||||
self._cfg_truncation = cfg_truncation
|
||||
|
||||
# 3. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = len(prompt_embeds)
|
||||
|
||||
# If prompt_embeds is provided and prompt is None, skip encoding
|
||||
if prompt_embeds is not None and prompt is None:
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"When `prompt_embeds` is provided without `prompt`, "
|
||||
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
|
||||
)
|
||||
else:
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.in_channels
|
||||
|
||||
# Repeat prompt_embeds for num_images_per_prompt
|
||||
if num_images_per_prompt > 1:
|
||||
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds:
|
||||
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
|
||||
actual_batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
# Calculate latent dimensions for image_seq_len
|
||||
latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
image_seq_len = (latent_height // 2) * (latent_width // 2)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
self.scheduler.sigma_min = 0.0
|
||||
scheduler_kwargs = {"mu": mu}
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
**scheduler_kwargs,
|
||||
)
|
||||
|
||||
# 6. Adjust timesteps based on strength
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
if num_inference_steps < 1:
|
||||
raise ValueError(
|
||||
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
|
||||
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
||||
)
|
||||
latent_timestep = timesteps[:1].repeat(actual_batch_size)
|
||||
|
||||
# 7. Prepare latents from image (returns noise and image_latents for blending)
|
||||
latents, noise, image_latents = self.prepare_latents(
|
||||
init_image,
|
||||
latent_timestep,
|
||||
actual_batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds[0].dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 8. Prepare mask and masked image latents
|
||||
# Create masked image: preserve only unmasked regions (mask=0)
|
||||
if masked_image_latents is None:
|
||||
masked_image = init_image * (mask < 0.5)
|
||||
else:
|
||||
masked_image = None # Will use provided masked_image_latents
|
||||
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask,
|
||||
masked_image if masked_image is not None else init_image,
|
||||
actual_batch_size,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds[0].dtype,
|
||||
device,
|
||||
generator,
|
||||
)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 9. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0])
|
||||
timestep = (1000 - timestep) / 1000
|
||||
# Normalized time for time-aware config (0 at start, 1 at end)
|
||||
t_norm = timestep[0].item()
|
||||
|
||||
# Handle cfg truncation
|
||||
current_guidance_scale = self.guidance_scale
|
||||
if (
|
||||
self.do_classifier_free_guidance
|
||||
and self._cfg_truncation is not None
|
||||
and float(self._cfg_truncation) <= 1
|
||||
):
|
||||
if t_norm > self._cfg_truncation:
|
||||
current_guidance_scale = 0.0
|
||||
|
||||
# Run CFG only if configured AND scale is non-zero
|
||||
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
|
||||
|
||||
if apply_cfg:
|
||||
latents_typed = latents.to(self.transformer.dtype)
|
||||
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
|
||||
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
|
||||
timestep_model_input = timestep.repeat(2)
|
||||
else:
|
||||
latent_model_input = latents.to(self.transformer.dtype)
|
||||
prompt_embeds_model_input = prompt_embeds
|
||||
timestep_model_input = timestep
|
||||
|
||||
latent_model_input = latent_model_input.unsqueeze(2)
|
||||
latent_model_input_list = list(latent_model_input.unbind(dim=0))
|
||||
|
||||
model_out_list = self.transformer(
|
||||
latent_model_input_list,
|
||||
timestep_model_input,
|
||||
prompt_embeds_model_input,
|
||||
)[0]
|
||||
|
||||
if apply_cfg:
|
||||
# Perform CFG
|
||||
pos_out = model_out_list[:actual_batch_size]
|
||||
neg_out = model_out_list[actual_batch_size:]
|
||||
|
||||
noise_pred = []
|
||||
for j in range(actual_batch_size):
|
||||
pos = pos_out[j].float()
|
||||
neg = neg_out[j].float()
|
||||
|
||||
pred = pos + current_guidance_scale * (pos - neg)
|
||||
|
||||
# Renormalization
|
||||
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
|
||||
ori_pos_norm = torch.linalg.vector_norm(pos)
|
||||
new_pos_norm = torch.linalg.vector_norm(pred)
|
||||
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
|
||||
if new_pos_norm > max_new_norm:
|
||||
pred = pred * (max_new_norm / new_pos_norm)
|
||||
|
||||
noise_pred.append(pred)
|
||||
|
||||
noise_pred = torch.stack(noise_pred, dim=0)
|
||||
else:
|
||||
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
|
||||
|
||||
noise_pred = noise_pred.squeeze(2)
|
||||
noise_pred = -noise_pred
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
|
||||
assert latents.dtype == torch.float32
|
||||
|
||||
# Inpainting blend: combine denoised latents with original image latents
|
||||
init_latents_proper = image_latents
|
||||
|
||||
# Re-scale original latents to current noise level for proper blending
|
||||
if i < len(timesteps) - 1:
|
||||
noise_timestep = timesteps[i + 1]
|
||||
init_latents_proper = self.scheduler.scale_noise(
|
||||
init_latents_proper, torch.tensor([noise_timestep]), noise
|
||||
)
|
||||
|
||||
# Blend: mask=1 for inpaint region (use denoised), mask=0 for preserve region (use original)
|
||||
latents = (1 - mask) * init_latents_proper + mask * latents
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
mask = callback_outputs.pop("mask", mask)
|
||||
masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ZImagePipelineOutput(images=image)
|
||||
@@ -79,7 +79,8 @@ MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
|
||||
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
|
||||
# there is no need to call any kernel for fp16/bf16
|
||||
if qweight_type in UNQUANTIZED_TYPES:
|
||||
return x @ qweight.T
|
||||
weight = dequantize_gguf_tensor(qweight)
|
||||
return x @ weight.T
|
||||
|
||||
# TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
|
||||
# contiguous batching and inefficient with diffusers' batching,
|
||||
|
||||
@@ -545,7 +545,9 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -867,7 +867,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -245,13 +245,26 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
if (
|
||||
sum(
|
||||
[
|
||||
self.config.use_beta_sigmas,
|
||||
self.config.use_exponential_sigmas,
|
||||
self.config.use_karras_sigmas,
|
||||
]
|
||||
)
|
||||
> 1
|
||||
):
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
||||
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
|
||||
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
|
||||
deprecate(
|
||||
"algorithm_types dpmsolver and sde-dpmsolver",
|
||||
"1.0.0",
|
||||
deprecation_message,
|
||||
)
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
@@ -259,7 +272,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
self.betas = (
|
||||
torch.linspace(
|
||||
beta_start**0.5,
|
||||
beta_end**0.5,
|
||||
num_train_timesteps,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
** 2
|
||||
)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
@@ -287,7 +308,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
# settings for DPM-Solver
|
||||
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
|
||||
if algorithm_type not in [
|
||||
"dpmsolver",
|
||||
"dpmsolver++",
|
||||
"sde-dpmsolver",
|
||||
"sde-dpmsolver++",
|
||||
]:
|
||||
if algorithm_type == "deis":
|
||||
self.register_to_config(algorithm_type="dpmsolver++")
|
||||
else:
|
||||
@@ -724,7 +750,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -738,7 +764,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
Returns:
|
||||
@@ -822,7 +848,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -832,8 +858,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
noise (`torch.Tensor`, *optional*):
|
||||
The noise tensor.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
@@ -860,7 +888,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
sigma_t, sigma_s = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
)
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
@@ -891,7 +922,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output_list: List[torch.Tensor],
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -901,7 +932,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.Tensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
Returns:
|
||||
@@ -1014,7 +1045,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output_list: List[torch.Tensor],
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -1024,8 +1055,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.Tensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
A current instance of a sample created by diffusion process.
|
||||
noise (`torch.Tensor`, *optional*):
|
||||
The noise tensor.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
@@ -1106,7 +1139,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return x_t
|
||||
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
@@ -1216,7 +1251,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample = sample.to(torch.float32)
|
||||
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
|
||||
noise = randn_tensor(
|
||||
model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
|
||||
model_output.shape,
|
||||
generator=generator,
|
||||
device=model_output.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
|
||||
noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
|
||||
|
||||
@@ -141,6 +141,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
use_flow_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
|
||||
flow_shift (`float`, *optional*, defaults to 1.0):
|
||||
The flow shift factor. Valid only when `use_flow_sigmas=True`.
|
||||
lambda_min_clipped (`float`, defaults to `-inf`):
|
||||
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
||||
cosine (`squaredcos_cap_v2`) noise schedule.
|
||||
@@ -163,15 +167,15 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
solver_order: int = 2,
|
||||
prediction_type: str = "epsilon",
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++",
|
||||
solver_type: Literal["midpoint", "heun"] = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
euler_at_final: bool = False,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
@@ -180,19 +184,32 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_flow_sigmas: Optional[bool] = False,
|
||||
flow_shift: Optional[float] = 1.0,
|
||||
lambda_min_clipped: float = -float("inf"),
|
||||
variance_type: Optional[str] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
variance_type: Optional[Literal["learned", "learned_range"]] = None,
|
||||
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
if (
|
||||
sum(
|
||||
[
|
||||
self.config.use_beta_sigmas,
|
||||
self.config.use_exponential_sigmas,
|
||||
self.config.use_karras_sigmas,
|
||||
]
|
||||
)
|
||||
> 1
|
||||
):
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
||||
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
|
||||
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
|
||||
deprecate(
|
||||
"algorithm_types dpmsolver and sde-dpmsolver",
|
||||
"1.0.0",
|
||||
deprecation_message,
|
||||
)
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
@@ -200,7 +217,15 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
self.betas = (
|
||||
torch.linspace(
|
||||
beta_start**0.5,
|
||||
beta_end**0.5,
|
||||
num_train_timesteps,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
** 2
|
||||
)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
@@ -219,7 +244,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
# settings for DPM-Solver
|
||||
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
|
||||
if algorithm_type not in [
|
||||
"dpmsolver",
|
||||
"dpmsolver++",
|
||||
"sde-dpmsolver",
|
||||
"sde-dpmsolver++",
|
||||
]:
|
||||
if algorithm_type == "deis":
|
||||
self.register_to_config(algorithm_type="dpmsolver++")
|
||||
else:
|
||||
@@ -250,7 +280,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -382,7 +416,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert sigma values to corresponding timestep values through interpolation.
|
||||
|
||||
@@ -419,7 +453,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Convert sigma values to alpha_t and sigma_t values.
|
||||
|
||||
@@ -441,7 +475,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
||||
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
|
||||
"""
|
||||
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
|
||||
Models](https://huggingface.co/papers/2206.00364).
|
||||
@@ -567,7 +601,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -581,7 +615,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
Returns:
|
||||
@@ -666,7 +700,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -676,8 +710,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
noise (`torch.Tensor`, *optional*):
|
||||
The noise tensor.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
@@ -704,7 +740,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
sigma_t, sigma_s = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
)
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
@@ -736,7 +775,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output_list: List[torch.Tensor],
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -746,7 +785,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.Tensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
Returns:
|
||||
@@ -860,7 +899,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output_list: List[torch.Tensor],
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -870,8 +909,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.Tensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
A current instance of a sample created by diffusion process.
|
||||
noise (`torch.Tensor`, *optional*):
|
||||
The noise tensor.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
@@ -951,7 +992,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return x_t
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
def _init_step_index(self, timestep: Union[int, torch.Tensor]):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
@@ -975,7 +1016,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output: torch.Tensor,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
sample: torch.Tensor,
|
||||
generator=None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
variance_noise: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
@@ -1027,7 +1068,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
|
||||
noise = randn_tensor(
|
||||
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
|
||||
model_output.shape,
|
||||
generator=generator,
|
||||
device=model_output.device,
|
||||
dtype=model_output.dtype,
|
||||
)
|
||||
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
|
||||
noise = variance_noise
|
||||
@@ -1074,6 +1118,21 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: torch.Tensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Add noise to the clean `original_samples` using the scheduler's equivalent function.
|
||||
|
||||
Args:
|
||||
original_samples (`torch.Tensor`):
|
||||
The original samples to add noise to.
|
||||
noise (`torch.Tensor`):
|
||||
The noise tensor.
|
||||
timesteps (`torch.IntTensor`):
|
||||
The timesteps at which to add noise.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The noisy samples.
|
||||
"""
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
@@ -1103,5 +1162,5 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -1120,7 +1120,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -662,7 +662,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -1122,7 +1122,9 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -1083,7 +1083,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -227,6 +227,21 @@ class LayerSkipConfig(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class MagCacheConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -284,6 +299,10 @@ def apply_layer_skip(*args, **kwargs):
|
||||
requires_backends(apply_layer_skip, ["torch"])
|
||||
|
||||
|
||||
def apply_mag_cache(*args, **kwargs):
|
||||
requires_backends(apply_mag_cache, ["torch"])
|
||||
|
||||
|
||||
def apply_pyramid_attention_broadcast(*args, **kwargs):
|
||||
requires_backends(apply_pyramid_attention_broadcast, ["torch"])
|
||||
|
||||
|
||||
@@ -4112,6 +4112,21 @@ class ZImageImg2ImgPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ZImageInpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ZImageOmniPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
244
tests/hooks/test_mag_cache.py
Normal file
244
tests/hooks/test_mag_cache.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import MagCacheConfig, apply_mag_cache
|
||||
from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
|
||||
from diffusers.models import ModelMixin
|
||||
from diffusers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DummyBlock(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
|
||||
# Output is double input
|
||||
# This ensures Residual = 2*Input - Input = Input
|
||||
return hidden_states * 2.0
|
||||
|
||||
|
||||
class DummyTransformer(ModelMixin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.transformer_blocks = torch.nn.ModuleList([DummyBlock(), DummyBlock()])
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None):
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TupleOutputBlock(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
|
||||
# Returns a tuple
|
||||
return hidden_states * 2.0, encoder_hidden_states
|
||||
|
||||
|
||||
class TupleTransformer(ModelMixin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.transformer_blocks = torch.nn.ModuleList([TupleOutputBlock()])
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None):
|
||||
for block in self.transformer_blocks:
|
||||
# Emulate Flux-like behavior
|
||||
output = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = output[0]
|
||||
encoder_hidden_states = output[1]
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class MagCacheTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Register standard dummy block
|
||||
TransformerBlockRegistry.register(
|
||||
DummyBlock,
|
||||
TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None),
|
||||
)
|
||||
# Register tuple block (Flux style)
|
||||
TransformerBlockRegistry.register(
|
||||
TupleOutputBlock,
|
||||
TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1),
|
||||
)
|
||||
|
||||
def _set_context(self, model, context_name):
|
||||
"""Helper to set context on all hooks in the model."""
|
||||
for module in model.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook._set_context(context_name)
|
||||
|
||||
def _get_calibration_data(self, model):
|
||||
for module in model.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
hook = module._diffusers_hook.get_hook("mag_cache_block_hook")
|
||||
if hook:
|
||||
return hook.state_manager.get_state().calibration_ratios
|
||||
return []
|
||||
|
||||
def test_mag_cache_validation(self):
|
||||
"""Test that missing mag_ratios raises ValueError."""
|
||||
with self.assertRaises(ValueError):
|
||||
MagCacheConfig(num_inference_steps=10, calibrate=False)
|
||||
|
||||
def test_mag_cache_skipping_logic(self):
|
||||
"""
|
||||
Tests that MagCache correctly calculates residuals and skips blocks when conditions are met.
|
||||
"""
|
||||
model = DummyTransformer()
|
||||
|
||||
# Dummy ratios: [1.0, 1.0] implies 0 accumulated error if we skip
|
||||
ratios = np.array([1.0, 1.0])
|
||||
|
||||
config = MagCacheConfig(
|
||||
threshold=100.0,
|
||||
num_inference_steps=2,
|
||||
retention_ratio=0.0, # Enable immediate skipping
|
||||
max_skip_steps=5,
|
||||
mag_ratios=ratios,
|
||||
)
|
||||
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
# Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each)
|
||||
# HeadInput=10. Output=40. Residual=30.
|
||||
input_t0 = torch.tensor([[[10.0]]])
|
||||
output_t0 = model(input_t0)
|
||||
self.assertTrue(torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed")
|
||||
|
||||
# Step 1: Input 11.0.
|
||||
# If Skipped: Output = Input(11) + Residual(30) = 41.0
|
||||
# If Computed: Output = 11 * 4 = 44.0
|
||||
input_t1 = torch.tensor([[[11.0]]])
|
||||
output_t1 = model(input_t1)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(output_t1, torch.tensor([[[41.0]]])), f"Expected Skip (41.0), got {output_t1.item()}"
|
||||
)
|
||||
|
||||
def test_mag_cache_retention(self):
|
||||
"""Test that retention_ratio prevents skipping even if error is low."""
|
||||
model = DummyTransformer()
|
||||
# Ratios that imply 0 error, so it *would* skip if retention allowed it
|
||||
ratios = np.array([1.0, 1.0])
|
||||
|
||||
config = MagCacheConfig(
|
||||
threshold=100.0,
|
||||
num_inference_steps=2,
|
||||
retention_ratio=1.0, # Force retention for ALL steps
|
||||
mag_ratios=ratios,
|
||||
)
|
||||
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
# Step 0
|
||||
model(torch.tensor([[[10.0]]]))
|
||||
|
||||
# Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention
|
||||
input_t1 = torch.tensor([[[11.0]]])
|
||||
output_t1 = model(input_t1)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(output_t1, torch.tensor([[[44.0]]])),
|
||||
f"Expected Compute (44.0) due to retention, got {output_t1.item()}",
|
||||
)
|
||||
|
||||
def test_mag_cache_tuple_outputs(self):
|
||||
"""Test compatibility with models returning (hidden, encoder_hidden) like Flux."""
|
||||
model = TupleTransformer()
|
||||
ratios = np.array([1.0, 1.0])
|
||||
|
||||
config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=ratios)
|
||||
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
# Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x)
|
||||
# Residual = 10.0
|
||||
input_t0 = torch.tensor([[[10.0]]])
|
||||
enc_t0 = torch.tensor([[[1.0]]])
|
||||
out_0, _ = model(input_t0, encoder_hidden_states=enc_t0)
|
||||
self.assertTrue(torch.allclose(out_0, torch.tensor([[[20.0]]])))
|
||||
|
||||
# Step 1: Skip. Input 11.0.
|
||||
# Skipped Output = 11 + 10 = 21.0
|
||||
input_t1 = torch.tensor([[[11.0]]])
|
||||
out_1, _ = model(input_t1, encoder_hidden_states=enc_t0)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(out_1, torch.tensor([[[21.0]]])), f"Tuple skip failed. Expected 21.0, got {out_1.item()}"
|
||||
)
|
||||
|
||||
def test_mag_cache_reset(self):
|
||||
"""Test that state resets correctly after num_inference_steps."""
|
||||
model = DummyTransformer()
|
||||
config = MagCacheConfig(
|
||||
threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0])
|
||||
)
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
input_t = torch.ones(1, 1, 1)
|
||||
|
||||
model(input_t) # Step 0
|
||||
model(input_t) # Step 1 (Skipped)
|
||||
|
||||
# Step 2 (Reset -> Step 0) -> Should Compute
|
||||
# Input 2.0 -> Output 8.0
|
||||
input_t2 = torch.tensor([[[2.0]]])
|
||||
output_t2 = model(input_t2)
|
||||
|
||||
self.assertTrue(torch.allclose(output_t2, torch.tensor([[[8.0]]])), "State did not reset correctly")
|
||||
|
||||
def test_mag_cache_calibration(self):
|
||||
"""Test that calibration mode records ratios."""
|
||||
model = DummyTransformer()
|
||||
config = MagCacheConfig(num_inference_steps=2, calibrate=True)
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
# Step 0
|
||||
# HeadInput = 10. Output = 40. Residual = 30.
|
||||
# Ratio 0 is placeholder 1.0
|
||||
model(torch.tensor([[[10.0]]]))
|
||||
|
||||
# Check intermediate state
|
||||
ratios = self._get_calibration_data(model)
|
||||
self.assertEqual(len(ratios), 1)
|
||||
self.assertEqual(ratios[0], 1.0)
|
||||
|
||||
# Step 1
|
||||
# HeadInput = 10. Output = 40. Residual = 30.
|
||||
# PrevResidual = 30. CurrResidual = 30.
|
||||
# Ratio = 30/30 = 1.0
|
||||
model(torch.tensor([[[10.0]]]))
|
||||
|
||||
# Verify it computes fully (no skip)
|
||||
# If it skipped, output would be 41.0. It should be 40.0
|
||||
# Actually in test setup, input is same (10.0) so output 40.0.
|
||||
# Let's ensure list is empty after reset (end of step 1)
|
||||
ratios_after = self._get_calibration_data(model)
|
||||
self.assertEqual(ratios_after, [])
|
||||
@@ -27,6 +27,7 @@ from ..test_pipelines_common import (
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
FluxIPAdapterTesterMixin,
|
||||
MagCacheTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
TaylorSeerCacheTesterMixin,
|
||||
@@ -41,6 +42,7 @@ class FluxPipelineFastTests(
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
TaylorSeerCacheTesterMixin,
|
||||
MagCacheTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = FluxPipeline
|
||||
|
||||
@@ -35,6 +35,7 @@ from diffusers import (
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
|
||||
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
|
||||
from diffusers.hooks.mag_cache import MagCacheConfig
|
||||
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
|
||||
from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
@@ -2976,6 +2977,59 @@ class TaylorSeerCacheTesterMixin:
|
||||
)
|
||||
|
||||
|
||||
class MagCacheTesterMixin:
|
||||
mag_cache_config = MagCacheConfig(
|
||||
threshold=0.06,
|
||||
max_skip_steps=3,
|
||||
retention_ratio=0.2,
|
||||
num_inference_steps=50,
|
||||
mag_ratios=torch.ones(50),
|
||||
)
|
||||
|
||||
def test_mag_cache_inference(self, expected_atol: float = 0.1):
|
||||
device = "cpu"
|
||||
|
||||
def create_pipe():
|
||||
torch.manual_seed(0)
|
||||
num_layers = 2
|
||||
components = self.get_dummy_components(num_layers=num_layers)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
return pipe
|
||||
|
||||
def run_forward(pipe):
|
||||
torch.manual_seed(0)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
# Match the config steps
|
||||
inputs["num_inference_steps"] = 50
|
||||
return pipe(**inputs)[0]
|
||||
|
||||
# 1. Run inference without MagCache (Baseline)
|
||||
pipe = create_pipe()
|
||||
output = run_forward(pipe).flatten()
|
||||
original_image_slice = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
# 2. Run inference with MagCache ENABLED
|
||||
pipe = create_pipe()
|
||||
pipe.transformer.enable_cache(self.mag_cache_config)
|
||||
output = run_forward(pipe).flatten()
|
||||
image_slice_enabled = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
# 3. Run inference with MagCache DISABLED
|
||||
pipe.transformer.disable_cache()
|
||||
output = run_forward(pipe).flatten()
|
||||
image_slice_disabled = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_enabled, atol=expected_atol), (
|
||||
"MagCache outputs should not differ too much from baseline."
|
||||
)
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-4), (
|
||||
"Outputs after disabling cache should match original inference exactly."
|
||||
)
|
||||
|
||||
|
||||
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
|
||||
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
|
||||
# reference image.
|
||||
|
||||
396
tests/pipelines/z_image/test_z_image_inpaint.py
Normal file
396
tests/pipelines/z_image/test_z_image_inpaint.py
Normal file
@@ -0,0 +1,396 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
ZImageInpaintPipeline,
|
||||
ZImageTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import floats_tensor
|
||||
|
||||
from ...testing_utils import torch_device
|
||||
from ..pipeline_params import (
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
|
||||
# Cannot use enable_full_determinism() which sets it to True
|
||||
# Note: Z-Image does not support FP16 inference due to complex64 RoPE embeddings
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||||
torch.use_deterministic_algorithms(False)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
if hasattr(torch.backends, "cuda"):
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class ZImageInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = ZImageInpaintPipeline
|
||||
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
|
||||
image_params = frozenset(["image", "mask_image"])
|
||||
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"strength",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
supports_dduf = False
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = ZImageTransformer2DModel(
|
||||
all_patch_size=(2,),
|
||||
all_f_patch_size=(1,),
|
||||
in_channels=16,
|
||||
dim=32,
|
||||
n_layers=2,
|
||||
n_refiner_layers=1,
|
||||
n_heads=2,
|
||||
n_kv_heads=2,
|
||||
norm_eps=1e-5,
|
||||
qk_norm=True,
|
||||
cap_feat_dim=16,
|
||||
rope_theta=256.0,
|
||||
t_scale=1000.0,
|
||||
axes_dims=[8, 4, 4],
|
||||
axes_lens=[256, 32, 32],
|
||||
)
|
||||
# `x_pad_token` and `cap_pad_token` are initialized with `torch.empty` which contains
|
||||
# uninitialized memory. Set them to known values for deterministic test behavior.
|
||||
with torch.no_grad():
|
||||
transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
|
||||
transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
block_out_channels=[32, 64],
|
||||
layers_per_block=1,
|
||||
latent_channels=16,
|
||||
norm_num_groups=32,
|
||||
sample_size=32,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = Qwen3Config(
|
||||
hidden_size=16,
|
||||
intermediate_size=16,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
vocab_size=151936,
|
||||
max_position_embeddings=512,
|
||||
)
|
||||
text_encoder = Qwen3Model(config)
|
||||
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
import random
|
||||
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
# Create mask: 1 = inpaint region, 0 = preserve region
|
||||
mask_image = torch.zeros((1, 1, 32, 32), device=device)
|
||||
mask_image[:, :, 8:24, 8:24] = 1.0 # Inpaint center region
|
||||
|
||||
inputs = {
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "bad quality",
|
||||
"image": image,
|
||||
"mask_image": mask_image,
|
||||
"strength": 1.0,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 3.0,
|
||||
"cfg_normalization": False,
|
||||
"cfg_truncation": 1.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "np",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
generated_image = image[0]
|
||||
self.assertEqual(generated_image.shape, (32, 32, 3))
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
|
||||
|
||||
def test_num_images_per_prompt(self):
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(self.pipeline_class.__call__)
|
||||
|
||||
if "num_images_per_prompt" not in sig.parameters:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
batch_sizes = [1, 2]
|
||||
num_images_per_prompts = [1, 2]
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for num_images_per_prompt in num_images_per_prompts:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
for key in inputs.keys():
|
||||
if key in self.batch_params:
|
||||
inputs[key] = batch_size * [inputs[key]]
|
||||
|
||||
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
|
||||
|
||||
assert images.shape[0] == batch_size * num_images_per_prompt
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
if not self.test_attention_slicing:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing1 = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=2)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing2 = pipe(**inputs)[0]
|
||||
|
||||
if test_max_difference:
|
||||
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
|
||||
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
|
||||
self.assertLess(
|
||||
max(max_diff1, max_diff2),
|
||||
expected_max_diff,
|
||||
"Attention slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_vae_tiling(self, expected_diff_max: float = 0.7):
|
||||
import random
|
||||
|
||||
generator_device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to("cpu")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Without tiling
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
# Generate a larger image for the input
|
||||
inputs["image"] = floats_tensor((1, 3, 128, 128), rng=random.Random(0)).to("cpu")
|
||||
# Generate a larger mask for the input
|
||||
mask = torch.zeros((1, 1, 128, 128), device="cpu")
|
||||
mask[:, :, 32:96, 32:96] = 1.0
|
||||
inputs["mask_image"] = mask
|
||||
output_without_tiling = pipe(**inputs)[0]
|
||||
|
||||
# With tiling (standard AutoencoderKL doesn't accept parameters)
|
||||
pipe.vae.enable_tiling()
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
inputs["image"] = floats_tensor((1, 3, 128, 128), rng=random.Random(0)).to("cpu")
|
||||
inputs["mask_image"] = mask
|
||||
output_with_tiling = pipe(**inputs)[0]
|
||||
|
||||
self.assertLess(
|
||||
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
|
||||
expected_diff_max,
|
||||
"VAE tiling should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-3):
|
||||
# Z-Image RoPE embeddings (complex64) have slightly higher numerical tolerance
|
||||
# Inpainting mask blending adds additional numerical variance
|
||||
super().test_pipeline_with_accelerator_device_map(expected_max_difference=expected_max_difference)
|
||||
|
||||
def test_group_offloading_inference(self):
|
||||
# Block-level offloading conflicts with RoPE cache. Pipeline-level offloading (tested separately) works fine.
|
||||
self.skipTest("Using test_pipeline_level_group_offloading_inference instead")
|
||||
|
||||
def test_save_load_float16(self, expected_max_diff=1e-2):
|
||||
# Z-Image does not support FP16 due to complex64 RoPE embeddings
|
||||
self.skipTest("Z-Image does not support FP16 inference")
|
||||
|
||||
def test_float16_inference(self, expected_max_diff=5e-2):
|
||||
# Z-Image does not support FP16 due to complex64 RoPE embeddings
|
||||
self.skipTest("Z-Image does not support FP16 inference")
|
||||
|
||||
def test_strength_parameter(self):
|
||||
"""Test that strength parameter affects the output correctly."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Test with different strength values
|
||||
inputs_low_strength = self.get_dummy_inputs(device)
|
||||
inputs_low_strength["strength"] = 0.2
|
||||
|
||||
inputs_high_strength = self.get_dummy_inputs(device)
|
||||
inputs_high_strength["strength"] = 0.8
|
||||
|
||||
# Both should complete without errors
|
||||
output_low = pipe(**inputs_low_strength).images[0]
|
||||
output_high = pipe(**inputs_high_strength).images[0]
|
||||
|
||||
# Outputs should be different (different amount of transformation)
|
||||
self.assertFalse(np.allclose(output_low, output_high, atol=1e-3))
|
||||
|
||||
def test_invalid_strength(self):
|
||||
"""Test that invalid strength values raise appropriate errors."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
|
||||
# Test strength < 0
|
||||
inputs["strength"] = -0.1
|
||||
with self.assertRaises(ValueError):
|
||||
pipe(**inputs)
|
||||
|
||||
# Test strength > 1
|
||||
inputs["strength"] = 1.5
|
||||
with self.assertRaises(ValueError):
|
||||
pipe(**inputs)
|
||||
|
||||
def test_mask_inpainting(self):
|
||||
"""Test that the mask properly controls which regions are inpainted."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Generate with full mask (inpaint everything)
|
||||
inputs_full = self.get_dummy_inputs(device)
|
||||
inputs_full["mask_image"] = torch.ones((1, 1, 32, 32), device=device)
|
||||
|
||||
# Generate with no mask (preserve everything)
|
||||
inputs_none = self.get_dummy_inputs(device)
|
||||
inputs_none["mask_image"] = torch.zeros((1, 1, 32, 32), device=device)
|
||||
|
||||
# Both should complete without errors
|
||||
output_full = pipe(**inputs_full).images[0]
|
||||
output_none = pipe(**inputs_none).images[0]
|
||||
|
||||
# Outputs should be different (full inpaint vs preserve)
|
||||
self.assertFalse(np.allclose(output_full, output_none, atol=1e-3))
|
||||
Reference in New Issue
Block a user