Compare commits

..

1 Commits

Author SHA1 Message Date
Sayak Paul
8da128067c Fix syntax error in quantization configuration 2026-02-04 10:10:14 +05:30
31 changed files with 199 additions and 2488 deletions

View File

@@ -53,41 +53,6 @@ image = pipe(
image.save("zimage_img2img.png")
```
## Inpainting
Use [`ZImageInpaintPipeline`] to inpaint specific regions of an image based on a text prompt and mask.
```python
import torch
import numpy as np
from PIL import Image
from diffusers import ZImageInpaintPipeline
from diffusers.utils import load_image
pipe = ZImageInpaintPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16)
pipe.to("cuda")
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
init_image = load_image(url).resize((1024, 1024))
# Create a mask (white = inpaint, black = preserve)
mask = np.zeros((1024, 1024), dtype=np.uint8)
mask[256:768, 256:768] = 255 # Inpaint center region
mask_image = Image.fromarray(mask)
prompt = "A beautiful lake with mountains in the background"
image = pipe(
prompt,
image=init_image,
mask_image=mask_image,
strength=1.0,
num_inference_steps=9,
guidance_scale=0.0,
generator=torch.Generator("cuda").manual_seed(42),
).images[0]
image.save("zimage_inpaint.png")
```
## ZImagePipeline
[[autodoc]] ZImagePipeline
@@ -99,9 +64,3 @@ image.save("zimage_inpaint.png")
[[autodoc]] ZImageImg2ImgPipeline
- all
- __call__
## ZImageInpaintPipeline
[[autodoc]] ZImageInpaintPipeline
- all
- __call__

View File

@@ -12,85 +12,179 @@ specific language governing permissions and limitations under the License.
# ComponentsManager
The [`ComponentsManager`] is a model registry and management system for Modular Diffusers. It adds and tracks models, stores useful metadata (model size, device placement, adapters), and supports offloading.
The [`ComponentsManager`] is a model registry and management system for Modular Diffusers. It adds and tracks models, stores useful metadata (model size, device placement, adapters), prevents duplicate model instances, and supports offloading.
This guide will show you how to use [`ComponentsManager`] to manage components and device memory.
## Connect to a pipeline
## Add a component
Create a [`ComponentsManager`] and pass it to a [`ModularPipeline`] with either [`~ModularPipeline.from_pretrained`] or [`~ModularPipelineBlocks.init_pipeline`].
The [`ComponentsManager`] should be created alongside a [`ModularPipeline`] in either [`~ModularPipeline.from_pretrained`] or [`~ModularPipelineBlocks.init_pipeline`].
> [!TIP]
> The `collection` parameter is optional but makes it easier to organize and manage components.
<hfoptions id="create">
<hfoption id="from_pretrained">
```py
from diffusers import ModularPipeline, ComponentsManager
import torch
manager = ComponentsManager()
pipe = ModularPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", components_manager=manager)
pipe.load_components(torch_dtype=torch.bfloat16)
comp = ComponentsManager()
pipe = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test1")
```
</hfoption>
<hfoption id="init_pipeline">
```py
from diffusers import ModularPipelineBlocks, ComponentsManager
import torch
manager = ComponentsManager()
blocks = ModularPipelineBlocks.from_pretrained("diffusers/Florence2-image-Annotator", trust_remote_code=True)
pipe= blocks.init_pipeline(components_manager=manager)
pipe.load_components(torch_dtype=torch.bfloat16)
from diffusers import ComponentsManager
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
components = ComponentsManager()
t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)
```
</hfoption>
</hfoptions>
Components loaded by the pipeline are automatically registered in the manager. You can inspect them right away.
## Inspect components
Print the [`ComponentsManager`] to see all registered components, including their class, device placement, dtype, memory size, and load ID.
The output below corresponds to the `from_pretrained` example above
Components are only loaded and registered when using [`~ModularPipeline.load_components`] or [`~ModularPipeline.load_components`]. The example below uses [`~ModularPipeline.load_components`] to create a second pipeline that reuses all the components from the first one, and assigns it to a different collection
```py
Components:
=============================================================================================================================
Models:
-----------------------------------------------------------------------------------------------------------------------------
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID
-----------------------------------------------------------------------------------------------------------------------------
text_encoder_140458257514752 | Qwen3Model | cpu | torch.bfloat16 | 7.49 | Tongyi-MAI/Z-Image-Turbo|text_encoder|null|null
vae_140458257515376 | AutoencoderKL | cpu | torch.bfloat16 | 0.16 | Tongyi-MAI/Z-Image-Turbo|vae|null|null
transformer_140458257515616 | ZImageTransformer2DModel | cpu | torch.bfloat16 | 11.46 | Tongyi-MAI/Z-Image-Turbo|transformer|null|null
-----------------------------------------------------------------------------------------------------------------------------
Other Components:
-----------------------------------------------------------------------------------------------------------------------------
ID | Class | Collection
-----------------------------------------------------------------------------------------------------------------------------
scheduler_140461023555264 | FlowMatchEulerDiscreteScheduler | N/A
tokenizer_140458256346432 | Qwen2Tokenizer | N/A
-----------------------------------------------------------------------------------------------------------------------------
pipe.load_components()
pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2")
```
The table shows models (with device, dtype, and memory info) separately from other components like schedulers and tokenizers. If any models have LoRA adapters, IP-Adapters, or quantization applied, that information is displayed in an additional section at the bottom.
Use the [`~ModularPipeline.null_component_names`] property to identify any components that need to be loaded, retrieve them with [`~ComponentsManager.get_components_by_names`], and then call [`~ModularPipeline.update_components`] to add the missing components.
```py
pipe2.null_component_names
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'image_encoder', 'unet', 'vae', 'scheduler', 'controlnet']
comp_dict = comp.get_components_by_names(names=pipe2.null_component_names)
pipe2.update_components(**comp_dict)
```
To add individual components, use the [`~ComponentsManager.add`] method. This registers a component with a unique id.
```py
from diffusers import AutoModel
text_encoder = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
component_id = comp.add("text_encoder", text_encoder)
comp
```
Use [`~ComponentsManager.remove`] to remove a component using their id.
```py
comp.remove("text_encoder_139917733042864")
```
## Retrieve a component
The [`ComponentsManager`] provides several methods to retrieve registered components.
### get_one
The [`~ComponentsManager.get_one`] method returns a single component and supports pattern matching for the `name` parameter. If multiple components match, [`~ComponentsManager.get_one`] returns an error.
| Pattern | Example | Description |
|-------------|----------------------------------|-------------------------------------------|
| exact | `comp.get_one(name="unet")` | exact name match |
| wildcard | `comp.get_one(name="unet*")` | names starting with "unet" |
| exclusion | `comp.get_one(name="!unet")` | exclude components named "unet" |
| or | `comp.get_one(name="unet&#124;vae")` | name is "unet" or "vae" |
[`~ComponentsManager.get_one`] also filters components by the `collection` argument or `load_id` argument.
```py
comp.get_one(name="unet", collection="sdxl")
```
### get_components_by_names
The [`~ComponentsManager.get_components_by_names`] method accepts a list of names and returns a dictionary mapping names to components. This is especially useful with [`ModularPipeline`] since they provide lists of required component names and the returned dictionary can be passed directly to [`~ModularPipeline.update_components`].
```py
component_dict = comp.get_components_by_names(names=["text_encoder", "unet", "vae"])
{"text_encoder": component1, "unet": component2, "vae": component3}
```
## Duplicate detection
It is recommended to load model components with [`ComponentSpec`] to assign components with a unique id that encodes their loading parameters. This allows [`ComponentsManager`] to automatically detect and prevent duplicate model instances even when different objects represent the same underlying checkpoint.
```py
from diffusers import ComponentSpec, ComponentsManager
from transformers import CLIPTextModel
comp = ComponentsManager()
# Create ComponentSpec for the first text encoder
spec = ComponentSpec(name="text_encoder", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=AutoModel)
# Create ComponentSpec for a duplicate text encoder (it is same checkpoint, from the same repo/subfolder)
spec_duplicated = ComponentSpec(name="text_encoder_duplicated", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=CLIPTextModel)
# Load and add both components - the manager will detect they're the same model
comp.add("text_encoder", spec.load())
comp.add("text_encoder_duplicated", spec_duplicated.load())
```
This returns a warning with instructions for removing the duplicate.
```py
ComponentsManager: adding component 'text_encoder_duplicated_139917580682672', but it has duplicate load_id 'stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null' with existing components: text_encoder_139918506246832. To remove a duplicate, call `components_manager.remove('<component_id>')`.
'text_encoder_duplicated_139917580682672'
```
You could also add a component without using [`ComponentSpec`] and duplicate detection still works in most cases even if you're adding the same component under a different name.
However, [`ComponentManager`] can't detect duplicates when you load the same component into different objects. In this case, you should load a model with [`ComponentSpec`].
```py
text_encoder_2 = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
comp.add("text_encoder", text_encoder_2)
'text_encoder_139917732983664'
```
## Collections
Collections are labels assigned to components for better organization and management. Add a component to a collection with the `collection` argument in [`~ComponentsManager.add`].
Only one component per name is allowed in each collection. Adding a second component with the same name automatically removes the first component.
```py
from diffusers import ComponentSpec, ComponentsManager
comp = ComponentsManager()
# Create ComponentSpec for the first UNet
spec = ComponentSpec(name="unet", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", type_hint=AutoModel)
# Create ComponentSpec for a different UNet
spec2 = ComponentSpec(name="unet", repo="RunDiffusion/Juggernaut-XL-v9", subfolder="unet", type_hint=AutoModel, variant="fp16")
# Add both UNets to the same collection - the second one will replace the first
comp.add("unet", spec.load(), collection="sdxl")
comp.add("unet", spec2.load(), collection="sdxl")
```
This makes it convenient to work with node-based systems because you can:
- Mark all models as loaded from one node with the `collection` label.
- Automatically replace models when new checkpoints are loaded under the same name.
- Batch delete all models in a collection when a node is removed.
## Offloading
The [`~ComponentsManager.enable_auto_cpu_offload`] method is a global offloading strategy that works across all models regardless of which pipeline is using them. Once enabled, you don't need to worry about device placement if you add or remove components.
```py
manager.enable_auto_cpu_offload(device="cuda")
comp.enable_auto_cpu_offload(device="cuda")
```
All models begin on the CPU and [`ComponentsManager`] moves them to the appropriate device right before they're needed, and moves other models back to the CPU when GPU memory is low.
To disable offloading, call [~ComponentsManager.disable_auto_cpu_offload].
```py
manager.disable_auto_cpu_offload()
```
You can set your own rules for which models to offload first.

View File

@@ -111,57 +111,3 @@ config = TaylorSeerCacheConfig(
)
pipe.transformer.enable_cache(config)
```
## MagCache
[MagCache](https://github.com/Zehong-Ma/MagCache) accelerates inference by skipping transformer blocks based on the magnitude of the residual update. It observes that the magnitude of updates (Output - Input) decays predictably over the diffusion process. By accumulating an "error budget" based on pre-computed magnitude ratios, it dynamically decides when to skip computation and reuse the previous residual.
MagCache relies on **Magnitude Ratios** (`mag_ratios`), which describe this decay curve. These ratios are specific to the model checkpoint and scheduler.
### Usage
To use MagCache, you typically follow a two-step process: **Calibration** and **Inference**.
1. **Calibration**: Run inference once with `calibrate=True`. The hook will measure the residual magnitudes and print the calculated ratios to the console.
2. **Inference**: Pass these ratios to `MagCacheConfig` to enable acceleration.
```python
import torch
from diffusers import FluxPipeline, MagCacheConfig
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16
).to("cuda")
# 1. Calibration Step
# Run full inference to measure model behavior.
calib_config = MagCacheConfig(calibrate=True, num_inference_steps=4)
pipe.transformer.enable_cache(calib_config)
# Run a prompt to trigger calibration
pipe("A cat playing chess", num_inference_steps=4)
# Logs will print something like: "MagCache Calibration Results: [1.0, 1.37, 0.97, 0.87]"
# 2. Inference Step
# Apply the specific ratios obtained from calibration for optimized speed.
# Note: For Flux models, you can also import defaults:
# from diffusers.hooks.mag_cache import FLUX_MAG_RATIOS
mag_config = MagCacheConfig(
mag_ratios=[1.0, 1.37, 0.97, 0.87],
num_inference_steps=4
)
pipe.transformer.enable_cache(mag_config)
image = pipe("A cat playing chess", num_inference_steps=4).images[0]
```
> [!NOTE]
> `mag_ratios` represent the model's intrinsic magnitude decay curve. Ratios calibrated for a high number of steps (e.g., 50) can be reused for lower step counts (e.g., 20). The implementation uses interpolation to map the curve to the current number of inference steps.
> [!TIP]
> For pipelines that run Classifier-Free Guidance sequentially (like Kandinsky 5.0), the calibration log might print two arrays: one for the Conditional pass and one for the Unconditional pass. In most cases, you should use the first array (Conditional).
> [!TIP]
> For pipelines that run Classifier-Free Guidance in a **batched** manner (like SDXL or Flux), the `hidden_states` processed by the model contain both conditional and unconditional branches concatenated together. The calibration process automatically accounts for this, producing a single array of ratios that represents the joint behavior. You can use this resulting array directly without modification.

View File

@@ -168,14 +168,12 @@ else:
"FirstBlockCacheConfig",
"HookRegistry",
"LayerSkipConfig",
"MagCacheConfig",
"PyramidAttentionBroadcastConfig",
"SmoothedEnergyGuidanceConfig",
"TaylorSeerCacheConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_layer_skip",
"apply_mag_cache",
"apply_pyramid_attention_broadcast",
"apply_taylorseer_cache",
]
@@ -696,7 +694,6 @@ else:
"ZImageControlNetInpaintPipeline",
"ZImageControlNetPipeline",
"ZImageImg2ImgPipeline",
"ZImageInpaintPipeline",
"ZImageOmniPipeline",
"ZImagePipeline",
]
@@ -935,14 +932,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FirstBlockCacheConfig,
HookRegistry,
LayerSkipConfig,
MagCacheConfig,
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
TaylorSeerCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_layer_skip,
apply_mag_cache,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
)
@@ -1429,7 +1424,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ZImageControlNetInpaintPipeline,
ZImageControlNetPipeline,
ZImageImg2ImgPipeline,
ZImageInpaintPipeline,
ZImageOmniPipeline,
ZImagePipeline,
)

View File

@@ -23,7 +23,6 @@ if is_torch_available():
from .hooks import HookRegistry, ModelHook
from .layer_skip import LayerSkipConfig, apply_layer_skip
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .mag_cache import MagCacheConfig, apply_mag_cache
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache

View File

@@ -23,13 +23,7 @@ from ..models.attention_processor import Attention, MochiAttention
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = (
"blocks",
"transformer_blocks",
"single_transformer_blocks",
"layers",
"visual_transformer_blocks",
)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")

View File

@@ -26,7 +26,6 @@ class AttentionProcessorMetadata:
class TransformerBlockMetadata:
return_hidden_states_index: int = None
return_encoder_hidden_states_index: int = None
hidden_states_argument_name: str = "hidden_states"
_cls: Type = None
_cached_parameter_indices: Dict[str, int] = None
@@ -170,7 +169,7 @@ def _register_attention_processors_metadata():
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock, JointTransformerBlock
from ..models.attention import BasicTransformerBlock
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_bria import BriaTransformerBlock
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
@@ -185,7 +184,6 @@ def _register_transformer_blocks_metadata():
HunyuanImageSingleTransformerBlock,
HunyuanImageTransformerBlock,
)
from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
@@ -333,24 +331,6 @@ def _register_transformer_blocks_metadata():
),
)
TransformerBlockRegistry.register(
model_class=JointTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
# Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock)
TransformerBlockRegistry.register(
model_class=Kandinsky5TransformerDecoderBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
hidden_states_argument_name="visual_embed",
),
)
# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):

View File

@@ -1,468 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from ._helpers import TransformerBlockRegistry
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
logger = get_logger(__name__) # pylint: disable=invalid-name
_MAG_CACHE_LEADER_BLOCK_HOOK = "mag_cache_leader_block_hook"
_MAG_CACHE_BLOCK_HOOK = "mag_cache_block_hook"
# Default Mag Ratios for Flux models (Dev/Schnell) are provided for convenience.
# Users must explicitly pass these to the config if using Flux.
# Reference: https://github.com/Zehong-Ma/MagCache
FLUX_MAG_RATIOS = torch.tensor(
[1.0]
+ [
1.21094,
1.11719,
1.07812,
1.0625,
1.03906,
1.03125,
1.03906,
1.02344,
1.03125,
1.02344,
0.98047,
1.01562,
1.00781,
1.0,
1.00781,
1.0,
1.00781,
1.0,
1.0,
0.99609,
0.99609,
0.98047,
0.98828,
0.96484,
0.95703,
0.93359,
0.89062,
]
)
def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor:
"""
Interpolate the source array to the target length using nearest neighbor interpolation.
"""
src_length = len(src_array)
if target_length == 1:
return src_array[-1:]
scale = (src_length - 1) / (target_length - 1)
grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32)
mapped_indices = torch.round(grid * scale).long()
return src_array[mapped_indices]
@dataclass
class MagCacheConfig:
r"""
Configuration for [MagCache](https://github.com/Zehong-Ma/MagCache).
Args:
threshold (`float`, defaults to `0.06`):
The threshold for the accumulated error. If the accumulated error is below this threshold, the block
computation is skipped. A higher threshold allows for more aggressive skipping (faster) but may degrade
quality.
max_skip_steps (`int`, defaults to `3`):
The maximum number of consecutive steps that can be skipped (K in the paper).
retention_ratio (`float`, defaults to `0.2`):
The fraction of initial steps during which skipping is disabled to ensure stability. For example, if
`num_inference_steps` is 28 and `retention_ratio` is 0.2, the first 6 steps will never be skipped.
num_inference_steps (`int`, defaults to `28`):
The number of inference steps used in the pipeline. This is required to interpolate `mag_ratios` correctly.
mag_ratios (`torch.Tensor`, *optional*):
The pre-computed magnitude ratios for the model. These are checkpoint-dependent. If not provided, you must
set `calibrate=True` to calculate them for your specific model. For Flux models, you can use
`diffusers.hooks.mag_cache.FLUX_MAG_RATIOS`.
calibrate (`bool`, defaults to `False`):
If True, enables calibration mode. In this mode, no blocks are skipped. Instead, the hook calculates the
magnitude ratios for the current run and logs them at the end. Use this to obtain `mag_ratios` for new
models or schedulers.
"""
threshold: float = 0.06
max_skip_steps: int = 3
retention_ratio: float = 0.2
num_inference_steps: int = 28
mag_ratios: Optional[Union[torch.Tensor, List[float]]] = None
calibrate: bool = False
def __post_init__(self):
# User MUST provide ratios OR enable calibration.
if self.mag_ratios is None and not self.calibrate:
raise ValueError(
" `mag_ratios` must be provided for MagCache inference because these ratios are model-dependent.\n"
"To get them for your model:\n"
"1. Initialize `MagCacheConfig(calibrate=True, ...)`\n"
"2. Run inference on your model once.\n"
"3. Copy the printed ratios array and pass it to `mag_ratios` in the config.\n"
"For Flux models, you can import `FLUX_MAG_RATIOS` from `diffusers.hooks.mag_cache`."
)
if not self.calibrate and self.mag_ratios is not None:
if not torch.is_tensor(self.mag_ratios):
self.mag_ratios = torch.tensor(self.mag_ratios)
if len(self.mag_ratios) != self.num_inference_steps:
logger.debug(
f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}"
)
self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps)
class MagCacheState(BaseState):
def __init__(self) -> None:
super().__init__()
# Cache for the residual (output - input) from the *previous* timestep
self.previous_residual: torch.Tensor = None
# State inputs/outputs for the current forward pass
self.head_block_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.should_compute: bool = True
# MagCache accumulators
self.accumulated_ratio: float = 1.0
self.accumulated_err: float = 0.0
self.accumulated_steps: int = 0
# Current step counter (timestep index)
self.step_index: int = 0
# Calibration storage
self.calibration_ratios: List[float] = []
def reset(self):
self.previous_residual = None
self.should_compute = True
self.accumulated_ratio = 1.0
self.accumulated_err = 0.0
self.accumulated_steps = 0
self.step_index = 0
self.calibration_ratios = []
class MagCacheHeadHook(ModelHook):
_is_stateful = True
def __init__(self, state_manager: StateManager, config: MagCacheConfig):
self.state_manager = state_manager
self.config = config
self._metadata = None
def initialize_hook(self, module):
unwrapped_module = unwrap_module(module)
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
return module
@torch.compiler.disable
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if self.state_manager._current_context is None:
self.state_manager.set_context("inference")
arg_name = self._metadata.hidden_states_argument_name
hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)
state: MagCacheState = self.state_manager.get_state()
state.head_block_input = hidden_states
should_compute = True
if self.config.calibrate:
# Never skip during calibration
should_compute = True
else:
# MagCache Logic
current_step = state.step_index
if current_step >= len(self.config.mag_ratios):
current_scale = 1.0
else:
current_scale = self.config.mag_ratios[current_step]
retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5)
if current_step >= retention_step:
state.accumulated_ratio *= current_scale
state.accumulated_steps += 1
state.accumulated_err += abs(1.0 - state.accumulated_ratio)
if (
state.previous_residual is not None
and state.accumulated_err <= self.config.threshold
and state.accumulated_steps <= self.config.max_skip_steps
):
should_compute = False
else:
state.accumulated_ratio = 1.0
state.accumulated_steps = 0
state.accumulated_err = 0.0
state.should_compute = should_compute
if not should_compute:
logger.debug(f"MagCache: Skipping step {state.step_index}")
# Apply MagCache: Output = Input + Previous Residual
output = hidden_states
res = state.previous_residual
if res.device != output.device:
res = res.to(output.device)
# Attempt to apply residual handling shape mismatches (e.g., text+image vs image only)
if res.shape == output.shape:
output = output + res
elif (
output.ndim == 3
and res.ndim == 3
and output.shape[0] == res.shape[0]
and output.shape[2] == res.shape[2]
):
# Assuming concatenation where image part is at the end (standard in Flux/SD3)
diff = output.shape[1] - res.shape[1]
if diff > 0:
output = output.clone()
output[:, diff:, :] = output[:, diff:, :] + res
else:
logger.warning(
f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. "
"Cannot apply residual safely. Returning input without residual."
)
else:
logger.warning(
f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. "
"Cannot apply residual safely. Returning input without residual."
)
if self._metadata.return_encoder_hidden_states_index is not None:
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
"encoder_hidden_states", args, kwargs
)
max_idx = max(
self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
)
ret_list = [None] * (max_idx + 1)
ret_list[self._metadata.return_hidden_states_index] = output
ret_list[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
return tuple(ret_list)
else:
return output
else:
# Compute original forward
output = self.fn_ref.original_forward(*args, **kwargs)
return output
def reset_state(self, module):
self.state_manager.reset()
return module
class MagCacheBlockHook(ModelHook):
def __init__(self, state_manager: StateManager, is_tail: bool = False, config: MagCacheConfig = None):
super().__init__()
self.state_manager = state_manager
self.is_tail = is_tail
self.config = config
self._metadata = None
def initialize_hook(self, module):
unwrapped_module = unwrap_module(module)
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
return module
@torch.compiler.disable
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if self.state_manager._current_context is None:
self.state_manager.set_context("inference")
state: MagCacheState = self.state_manager.get_state()
if not state.should_compute:
arg_name = self._metadata.hidden_states_argument_name
hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)
if self.is_tail:
# Still need to advance step index even if we skip
self._advance_step(state)
if self._metadata.return_encoder_hidden_states_index is not None:
encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
"encoder_hidden_states", args, kwargs
)
max_idx = max(
self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
)
ret_list = [None] * (max_idx + 1)
ret_list[self._metadata.return_hidden_states_index] = hidden_states
ret_list[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
return tuple(ret_list)
return hidden_states
output = self.fn_ref.original_forward(*args, **kwargs)
if self.is_tail:
# Calculate residual for next steps
if isinstance(output, tuple):
out_hidden = output[self._metadata.return_hidden_states_index]
else:
out_hidden = output
in_hidden = state.head_block_input
if in_hidden is None:
return output
# Determine residual
if out_hidden.shape == in_hidden.shape:
residual = out_hidden - in_hidden
elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]:
diff = in_hidden.shape[1] - out_hidden.shape[1]
if diff == 0:
residual = out_hidden - in_hidden
else:
residual = out_hidden - in_hidden # Fallback to matching tail
else:
# Fallback for completely mismatched shapes
residual = out_hidden
if self.config.calibrate:
self._perform_calibration_step(state, residual)
state.previous_residual = residual
self._advance_step(state)
return output
def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor):
if state.previous_residual is None:
# First step has no previous residual to compare against.
# log 1.0 as a neutral starting point.
ratio = 1.0
else:
# MagCache Calibration Formula: mean(norm(curr) / norm(prev))
# norm(dim=-1) gives magnitude of each token vector
curr_norm = torch.linalg.norm(current_residual.float(), dim=-1)
prev_norm = torch.linalg.norm(state.previous_residual.float(), dim=-1)
# Avoid division by zero
ratio = (curr_norm / (prev_norm + 1e-8)).mean().item()
state.calibration_ratios.append(ratio)
def _advance_step(self, state: MagCacheState):
state.step_index += 1
if state.step_index >= self.config.num_inference_steps:
# End of inference loop
if self.config.calibrate:
print("\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):")
print(f"{state.calibration_ratios}\n")
logger.info(f"MagCache Calibration Results: {state.calibration_ratios}")
# Reset state
state.step_index = 0
state.accumulated_ratio = 1.0
state.accumulated_steps = 0
state.accumulated_err = 0.0
state.previous_residual = None
state.calibration_ratios = []
def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
"""
Applies MagCache to a given module (typically a Transformer).
Args:
module (`torch.nn.Module`):
The module to apply MagCache to.
config (`MagCacheConfig`):
The configuration for MagCache.
"""
# Initialize registry on the root module so the Pipeline can set context.
HookRegistry.check_if_exists_or_initialize(module)
state_manager = StateManager(MagCacheState, (), {})
remaining_blocks = []
for name, submodule in module.named_children():
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
continue
for index, block in enumerate(submodule):
remaining_blocks.append((f"{name}.{index}", block))
if not remaining_blocks:
logger.warning("MagCache: No transformer blocks found to apply hooks.")
return
# Handle single-block models
if len(remaining_blocks) == 1:
name, block = remaining_blocks[0]
logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{name}'")
_apply_mag_cache_block_hook(block, state_manager, config, is_tail=True)
_apply_mag_cache_head_hook(block, state_manager, config)
return
head_block_name, head_block = remaining_blocks.pop(0)
tail_block_name, tail_block = remaining_blocks.pop(-1)
logger.info(f"MagCache: Applying Head Hook to {head_block_name}")
_apply_mag_cache_head_hook(head_block, state_manager, config)
for name, block in remaining_blocks:
_apply_mag_cache_block_hook(block, state_manager, config)
logger.info(f"MagCache: Applying Tail Hook to {tail_block_name}")
_apply_mag_cache_block_hook(tail_block, state_manager, config, is_tail=True)
def _apply_mag_cache_head_hook(block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
# Automatically remove existing hook to allow re-application (e.g. switching modes)
if registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) is not None:
registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK)
hook = MagCacheHeadHook(state_manager, config)
registry.register_hook(hook, _MAG_CACHE_LEADER_BLOCK_HOOK)
def _apply_mag_cache_block_hook(
block: torch.nn.Module,
state_manager: StateManager,
config: MagCacheConfig,
is_tail: bool = False,
) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
# Automatically remove existing hook to allow re-application
if registry.get_hook(_MAG_CACHE_BLOCK_HOOK) is not None:
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK)
hook = MagCacheBlockHook(state_manager, is_tail, config)
registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK)

View File

@@ -68,12 +68,10 @@ class CacheMixin:
from ..hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
MagCacheConfig,
PyramidAttentionBroadcastConfig,
TaylorSeerCacheConfig,
apply_faster_cache,
apply_first_block_cache,
apply_mag_cache,
apply_pyramid_attention_broadcast,
apply_taylorseer_cache,
)
@@ -87,8 +85,6 @@ class CacheMixin:
apply_faster_cache(self, config)
elif isinstance(config, FirstBlockCacheConfig):
apply_first_block_cache(self, config)
elif isinstance(config, MagCacheConfig):
apply_mag_cache(self, config)
elif isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, TaylorSeerCacheConfig):
@@ -103,13 +99,11 @@ class CacheMixin:
FasterCacheConfig,
FirstBlockCacheConfig,
HookRegistry,
MagCacheConfig,
PyramidAttentionBroadcastConfig,
TaylorSeerCacheConfig,
)
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
@@ -124,9 +118,6 @@ class CacheMixin:
elif isinstance(self._cache_config, FirstBlockCacheConfig):
registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, MagCacheConfig):
registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True)
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, TaylorSeerCacheConfig):

View File

@@ -125,9 +125,9 @@ class BriaFiboAttnProcessor:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states.contiguous())
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states.contiguous())
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:

View File

@@ -130,9 +130,9 @@ class FluxAttnProcessor:
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
)
hidden_states = attn.to_out[0](hidden_states.contiguous())
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states.contiguous())
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:

View File

@@ -561,11 +561,11 @@ class QwenDoubleStreamAttnProcessor2_0:
img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
# Apply output projections
img_attn_output = attn.to_out[0](img_attn_output.contiguous())
img_attn_output = attn.to_out[0](img_attn_output)
if len(attn.to_out) > 1:
img_attn_output = attn.to_out[1](img_attn_output) # dropout
txt_attn_output = attn.to_add_out(txt_attn_output.contiguous())
txt_attn_output = attn.to_add_out(txt_attn_output)
return img_attn_output, txt_attn_output

View File

@@ -410,12 +410,11 @@ else:
"Kandinsky5I2IPipeline",
]
_import_structure["z_image"] = [
"ZImageControlNetInpaintPipeline",
"ZImageControlNetPipeline",
"ZImageImg2ImgPipeline",
"ZImageInpaintPipeline",
"ZImageOmniPipeline",
"ZImagePipeline",
"ZImageControlNetPipeline",
"ZImageControlNetInpaintPipeline",
"ZImageOmniPipeline",
]
_import_structure["skyreels_v2"] = [
"SkyReelsV2DiffusionForcingPipeline",
@@ -871,7 +870,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ZImageControlNetInpaintPipeline,
ZImageControlNetPipeline,
ZImageImg2ImgPipeline,
ZImageInpaintPipeline,
ZImageOmniPipeline,
ZImagePipeline,
)

View File

@@ -127,7 +127,6 @@ from .z_image import (
ZImageControlNetInpaintPipeline,
ZImageControlNetPipeline,
ZImageImg2ImgPipeline,
ZImageInpaintPipeline,
ZImageOmniPipeline,
ZImagePipeline,
)
@@ -236,7 +235,6 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
("qwenimage", QwenImageInpaintPipeline),
("qwenimage-edit", QwenImageEditInpaintPipeline),
("z-image", ZImageInpaintPipeline),
]
)

View File

@@ -26,7 +26,6 @@ else:
_import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"]
_import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"]
_import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"]
_import_structure["pipeline_z_image_inpaint"] = ["ZImageInpaintPipeline"]
_import_structure["pipeline_z_image_omni"] = ["ZImageOmniPipeline"]
@@ -43,7 +42,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_z_image_controlnet import ZImageControlNetPipeline
from .pipeline_z_image_controlnet_inpaint import ZImageControlNetInpaintPipeline
from .pipeline_z_image_img2img import ZImageImg2ImgPipeline
from .pipeline_z_image_inpaint import ZImageInpaintPipeline
from .pipeline_z_image_omni import ZImageOmniPipeline
else:
import sys

View File

@@ -1,932 +0,0 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import AutoTokenizer, PreTrainedModel
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import ZImageTransformer2DModel
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import ZImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import ZImageInpaintPipeline
>>> from diffusers.utils import load_image
>>> pipe = ZImageInpaintPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
>>> init_image = load_image(url).resize((1024, 1024))
>>> # Create a mask (white = inpaint, black = preserve)
>>> import numpy as np
>>> from PIL import Image
>>> mask = np.zeros((1024, 1024), dtype=np.uint8)
>>> mask[256:768, 256:768] = 255 # Inpaint center region
>>> mask_image = Image.fromarray(mask)
>>> prompt = "A beautiful lake with mountains in the background"
>>> image = pipe(
... prompt,
... image=init_image,
... mask_image=mask_image,
... strength=1.0,
... num_inference_steps=9,
... guidance_scale=0.0,
... generator=torch.Generator("cuda").manual_seed(42),
... ).images[0]
>>> image.save("zimage_inpaint.png")
```
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class ZImageInpaintPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
r"""
The ZImage pipeline for inpainting.
Args:
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`PreTrainedModel`]):
A text encoder model to encode text prompts.
tokenizer ([`AutoTokenizer`]):
A tokenizer to tokenize text prompts.
transformer ([`ZImageTransformer2DModel`]):
A ZImage transformer model to denoise the encoded image latents.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds", "mask", "masked_image_latents"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: PreTrainedModel,
tokenizer: AutoTokenizer,
transformer: ZImageTransformer2DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
transformer=transformer,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor * 2,
do_normalize=False,
do_binarize=True,
do_convert_grayscale=True,
)
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_embeds = self._encode_prompt(
prompt=prompt,
device=device,
prompt_embeds=prompt_embeds,
max_sequence_length=max_sequence_length,
)
if do_classifier_free_guidance:
if negative_prompt is None:
negative_prompt = ["" for _ in prompt]
else:
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
assert len(prompt) == len(negative_prompt)
negative_prompt_embeds = self._encode_prompt(
prompt=negative_prompt,
device=device,
prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
)
else:
negative_prompt_embeds = []
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline._encode_prompt
def _encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
device = device or self._execution_device
if prompt_embeds is not None:
return prompt_embeds
if isinstance(prompt, str):
prompt = [prompt]
for i, prompt_item in enumerate(prompt):
messages = [
{"role": "user", "content": prompt_item},
]
prompt_item = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
prompt[i] = prompt_item
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_masks = text_inputs.attention_mask.to(device).bool()
prompt_embeds = self.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-2]
embeddings_list = []
for i in range(len(prompt_embeds)):
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
return embeddings_list
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps)
t_start = int(max(num_inference_steps - init_timestep, 0))
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
def prepare_mask_latents(
self,
mask,
masked_image,
batch_size,
height,
width,
dtype,
device,
generator,
):
"""Prepare mask and masked image latents for inpainting.
Args:
mask: Binary mask tensor where 1 = inpaint region, 0 = preserve region.
masked_image: Original image with masked regions zeroed out.
batch_size: Number of images to generate.
height: Output image height.
width: Output image width.
dtype: Data type for the tensors.
device: Device to place tensors on.
generator: Random generator for reproducibility.
Returns:
Tuple of (mask, masked_image_latents) prepared for the denoising loop.
"""
# Calculate latent dimensions
latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
# Resize mask to latent dimensions
mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width), mode="nearest")
mask = mask.to(device=device, dtype=dtype)
# Encode masked image to latents
masked_image = masked_image.to(device=device, dtype=dtype)
if isinstance(generator, list):
masked_image_latents = [
retrieve_latents(self.vae.encode(masked_image[i : i + 1]), generator=generator[i])
for i in range(masked_image.shape[0])
]
masked_image_latents = torch.cat(masked_image_latents, dim=0)
else:
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
# Apply VAE scaling
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# Expand for batch size
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
return mask, masked_image_latents
def prepare_latents(
self,
image,
timestep,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
"""Prepare latents for inpainting, returning noise and image_latents for blending.
Returns:
Tuple of (latents, noise, image_latents) where:
- latents: Noised image latents for denoising
- noise: The noise tensor used for blending
- image_latents: Clean image latents for blending
"""
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
# Generate noise for blending even if latents are provided
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# Encode image for blending
image = image.to(device=device, dtype=dtype)
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
image_latents = torch.cat([image_latents] * (batch_size // image_latents.shape[0]), dim=0)
return latents.to(device=device, dtype=dtype), noise, image_latents
# Encode the input image
image = image.to(device=device, dtype=dtype)
if image.shape[1] != num_channels_latents:
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
# Apply scaling (inverse of decoding: decode does latents/scaling_factor + shift_factor)
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
else:
image_latents = image
# Handle batch size expansion
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
# Generate noise for both initial noising and later blending
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# Add noise using flow matching scale_noise
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
return latents, noise, image_latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
def check_inputs(
self,
prompt,
image,
mask_image,
strength,
height,
width,
output_type,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if image is None:
raise ValueError("`image` input cannot be undefined for inpainting.")
if mask_image is None:
raise ValueError("`mask_image` input cannot be undefined for inpainting.")
if output_type not in ["latent", "pil", "np", "pt"]:
raise ValueError(f"`output_type` must be one of 'latent', 'pil', 'np', or 'pt', but got {output_type}")
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
masked_image_latents: Optional[torch.FloatTensor] = None,
strength: float = 1.0,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 5.0,
cfg_normalization: bool = False,
cfg_truncation: float = 1.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
Function invoked when calling the pipeline for inpainting.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing a mask image for inpainting. White pixels (value 1) in the
mask will be inpainted, black pixels (value 0) will be preserved from the original image.
masked_image_latents (`torch.FloatTensor`, *optional*):
Pre-encoded masked image latents. If provided, the masked image encoding step will be skipped.
strength (`float`, *optional*, defaults to 1.0):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
essentially ignores `image` in the masked region.
height (`int`, *optional*, defaults to 1024):
The height in pixels of the generated image. If not provided, uses the input image height.
width (`int`, *optional*, defaults to 1024):
The width in pixels of the generated image. If not provided, uses the input image width.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
cfg_normalization (`bool`, *optional*, defaults to False):
Whether to apply configuration normalization.
cfg_truncation (`float`, *optional*, defaults to 1.0):
The truncation value for configuration.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, *optional*, defaults to 512):
Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
# 1. Check inputs
self.check_inputs(
prompt=prompt,
image=image,
mask_image=mask_image,
strength=strength,
height=height,
width=width,
output_type=output_type,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
# 2. Preprocess image and mask
init_image = self.image_processor.preprocess(image)
init_image = init_image.to(dtype=torch.float32)
# Get dimensions from the preprocessed image if not specified
if height is None:
height = init_image.shape[-2]
if width is None:
width = init_image.shape[-1]
vae_scale = self.vae_scale_factor * 2
if height % vae_scale != 0:
raise ValueError(
f"Height must be divisible by {vae_scale} (got {height}). "
f"Please adjust the height to a multiple of {vae_scale}."
)
if width % vae_scale != 0:
raise ValueError(
f"Width must be divisible by {vae_scale} (got {width}). "
f"Please adjust the width to a multiple of {vae_scale}."
)
# Preprocess mask
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
device = self._execution_device
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
self._cfg_normalization = cfg_normalization
self._cfg_truncation = cfg_truncation
# 3. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = len(prompt_embeds)
# If prompt_embeds is provided and prompt is None, skip encoding
if prompt_embeds is not None and prompt is None:
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
raise ValueError(
"When `prompt_embeds` is provided without `prompt`, "
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
)
else:
(
prompt_embeds,
negative_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.in_channels
# Repeat prompt_embeds for num_images_per_prompt
if num_images_per_prompt > 1:
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
if self.do_classifier_free_guidance and negative_prompt_embeds:
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
actual_batch_size = batch_size * num_images_per_prompt
# Calculate latent dimensions for image_seq_len
latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
image_seq_len = (latent_height // 2) * (latent_width // 2)
# 5. Prepare timesteps
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
self.scheduler.sigma_min = 0.0
scheduler_kwargs = {"mu": mu}
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
**scheduler_kwargs,
)
# 6. Adjust timesteps based on strength
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
if num_inference_steps < 1:
raise ValueError(
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
)
latent_timestep = timesteps[:1].repeat(actual_batch_size)
# 7. Prepare latents from image (returns noise and image_latents for blending)
latents, noise, image_latents = self.prepare_latents(
init_image,
latent_timestep,
actual_batch_size,
num_channels_latents,
height,
width,
prompt_embeds[0].dtype,
device,
generator,
latents,
)
# 8. Prepare mask and masked image latents
# Create masked image: preserve only unmasked regions (mask=0)
if masked_image_latents is None:
masked_image = init_image * (mask < 0.5)
else:
masked_image = None # Will use provided masked_image_latents
mask, masked_image_latents = self.prepare_mask_latents(
mask,
masked_image if masked_image is not None else init_image,
actual_batch_size,
height,
width,
prompt_embeds[0].dtype,
device,
generator,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 9. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
timestep = (1000 - timestep) / 1000
# Normalized time for time-aware config (0 at start, 1 at end)
t_norm = timestep[0].item()
# Handle cfg truncation
current_guidance_scale = self.guidance_scale
if (
self.do_classifier_free_guidance
and self._cfg_truncation is not None
and float(self._cfg_truncation) <= 1
):
if t_norm > self._cfg_truncation:
current_guidance_scale = 0.0
# Run CFG only if configured AND scale is non-zero
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
if apply_cfg:
latents_typed = latents.to(self.transformer.dtype)
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
timestep_model_input = timestep.repeat(2)
else:
latent_model_input = latents.to(self.transformer.dtype)
prompt_embeds_model_input = prompt_embeds
timestep_model_input = timestep
latent_model_input = latent_model_input.unsqueeze(2)
latent_model_input_list = list(latent_model_input.unbind(dim=0))
model_out_list = self.transformer(
latent_model_input_list,
timestep_model_input,
prompt_embeds_model_input,
)[0]
if apply_cfg:
# Perform CFG
pos_out = model_out_list[:actual_batch_size]
neg_out = model_out_list[actual_batch_size:]
noise_pred = []
for j in range(actual_batch_size):
pos = pos_out[j].float()
neg = neg_out[j].float()
pred = pos + current_guidance_scale * (pos - neg)
# Renormalization
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
ori_pos_norm = torch.linalg.vector_norm(pos)
new_pos_norm = torch.linalg.vector_norm(pred)
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
if new_pos_norm > max_new_norm:
pred = pred * (max_new_norm / new_pos_norm)
noise_pred.append(pred)
noise_pred = torch.stack(noise_pred, dim=0)
else:
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
noise_pred = noise_pred.squeeze(2)
noise_pred = -noise_pred
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
assert latents.dtype == torch.float32
# Inpainting blend: combine denoised latents with original image latents
init_latents_proper = image_latents
# Re-scale original latents to current noise level for proper blending
if i < len(timesteps) - 1:
noise_timestep = timesteps[i + 1]
init_latents_proper = self.scheduler.scale_noise(
init_latents_proper, torch.tensor([noise_timestep]), noise
)
# Blend: mask=1 for inpaint region (use denoised), mask=0 for preserve region (use original)
latents = (1 - mask) * init_latents_proper + mask * latents
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
mask = callback_outputs.pop("mask", mask)
masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
latents = latents.to(self.vae.dtype)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return ZImagePipelineOutput(images=image)

View File

@@ -79,8 +79,7 @@ MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
# there is no need to call any kernel for fp16/bf16
if qweight_type in UNQUANTIZED_TYPES:
weight = dequantize_gguf_tensor(qweight)
return x @ weight.T
return x @ qweight.T
# TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
# contiguous batching and inefficient with diffusers' batching,

View File

@@ -545,9 +545,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.

View File

@@ -867,9 +867,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.

View File

@@ -245,26 +245,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if (
sum(
[
self.config.use_beta_sigmas,
self.config.use_exponential_sigmas,
self.config.use_karras_sigmas,
]
)
> 1
):
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate(
"algorithm_types dpmsolver and sde-dpmsolver",
"1.0.0",
deprecation_message,
)
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -272,15 +259,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -308,12 +287,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0
# settings for DPM-Solver
if algorithm_type not in [
"dpmsolver",
"dpmsolver++",
"sde-dpmsolver",
"sde-dpmsolver++",
]:
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++")
else:
@@ -750,7 +724,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self,
model_output: torch.Tensor,
*args,
sample: Optional[torch.Tensor] = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
"""
@@ -764,7 +738,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.Tensor`, *optional*):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
@@ -848,7 +822,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self,
model_output: torch.Tensor,
*args,
sample: Optional[torch.Tensor] = None,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -858,10 +832,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.Tensor`, *optional*):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
noise (`torch.Tensor`, *optional*):
The noise tensor.
Returns:
`torch.Tensor`:
@@ -888,10 +860,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -922,7 +891,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self,
model_output_list: List[torch.Tensor],
*args,
sample: Optional[torch.Tensor] = None,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -932,7 +901,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`, *optional*):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
@@ -1045,7 +1014,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self,
model_output_list: List[torch.Tensor],
*args,
sample: Optional[torch.Tensor] = None,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -1055,10 +1024,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`, *optional*):
sample (`torch.Tensor`):
A current instance of a sample created by diffusion process.
noise (`torch.Tensor`, *optional*):
The noise tensor.
Returns:
`torch.Tensor`:
@@ -1139,9 +1106,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return x_t
def index_for_timestep(
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.
@@ -1251,10 +1216,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
sample = sample.to(torch.float32)
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
noise = randn_tensor(
model_output.shape,
generator=generator,
device=model_output.device,
dtype=torch.float32,
model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
)
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
noise = variance_noise.to(device=model_output.device, dtype=torch.float32)

View File

@@ -141,10 +141,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
use_flow_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
flow_shift (`float`, *optional*, defaults to 1.0):
The flow shift factor. Valid only when `use_flow_sigmas=True`.
lambda_min_clipped (`float`, defaults to `-inf`):
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
cosine (`squaredcos_cap_v2`) noise schedule.
@@ -167,15 +163,15 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
beta_schedule: str = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
sample_max_value: float = 1.0,
algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++",
solver_type: Literal["midpoint", "heun"] = "midpoint",
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
@@ -184,32 +180,19 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
use_flow_sigmas: Optional[bool] = False,
flow_shift: Optional[float] = 1.0,
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[Literal["learned", "learned_range"]] = None,
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
variance_type: Optional[str] = None,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
if (
sum(
[
self.config.use_beta_sigmas,
self.config.use_exponential_sigmas,
self.config.use_karras_sigmas,
]
)
> 1
):
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
raise ValueError(
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
)
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
deprecate(
"algorithm_types dpmsolver and sde-dpmsolver",
"1.0.0",
deprecation_message,
)
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -217,15 +200,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -244,12 +219,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self.init_noise_sigma = 1.0
# settings for DPM-Solver
if algorithm_type not in [
"dpmsolver",
"dpmsolver++",
"sde-dpmsolver",
"sde-dpmsolver++",
]:
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
if algorithm_type == "deis":
self.register_to_config(algorithm_type="dpmsolver++")
else:
@@ -280,11 +250,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
"""
return self._step_index
def set_timesteps(
self,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
):
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -416,7 +382,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
return sample
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
def _sigma_to_t(self, sigma, log_sigmas):
"""
Convert sigma values to corresponding timestep values through interpolation.
@@ -453,7 +419,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
return t
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def _sigma_to_alpha_sigma_t(self, sigma):
"""
Convert sigma values to alpha_t and sigma_t values.
@@ -475,7 +441,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
return alpha_t, sigma_t
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
"""
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
Models](https://huggingface.co/papers/2206.00364).
@@ -601,7 +567,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self,
model_output: torch.Tensor,
*args,
sample: Optional[torch.Tensor] = None,
sample: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
"""
@@ -615,7 +581,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.Tensor`, *optional*):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
@@ -700,7 +666,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self,
model_output: torch.Tensor,
*args,
sample: Optional[torch.Tensor] = None,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -710,10 +676,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sample (`torch.Tensor`, *optional*):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
noise (`torch.Tensor`, *optional*):
The noise tensor.
Returns:
`torch.Tensor`:
@@ -740,10 +704,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
sigma_t, sigma_s = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
)
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -775,7 +736,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self,
model_output_list: List[torch.Tensor],
*args,
sample: Optional[torch.Tensor] = None,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -785,7 +746,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`, *optional*):
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
@@ -899,7 +860,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
self,
model_output_list: List[torch.Tensor],
*args,
sample: Optional[torch.Tensor] = None,
sample: torch.Tensor = None,
noise: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
@@ -909,10 +870,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output_list (`List[torch.Tensor]`):
The direct outputs from learned diffusion model at current and latter timesteps.
sample (`torch.Tensor`, *optional*):
sample (`torch.Tensor`):
A current instance of a sample created by diffusion process.
noise (`torch.Tensor`, *optional*):
The noise tensor.
Returns:
`torch.Tensor`:
@@ -992,7 +951,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
)
return x_t
def _init_step_index(self, timestep: Union[int, torch.Tensor]):
def _init_step_index(self, timestep):
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
@@ -1016,7 +975,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator: Optional[torch.Generator] = None,
generator=None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
@@ -1068,10 +1027,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
noise = randn_tensor(
model_output.shape,
generator=generator,
device=model_output.device,
dtype=model_output.dtype,
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
)
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
noise = variance_noise
@@ -1118,21 +1074,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
noise: torch.Tensor,
timesteps: torch.IntTensor,
) -> torch.Tensor:
"""
Add noise to the clean `original_samples` using the scheduler's equivalent function.
Args:
original_samples (`torch.Tensor`):
The original samples to add noise to.
noise (`torch.Tensor`):
The noise tensor.
timesteps (`torch.IntTensor`):
The timesteps at which to add noise.
Returns:
`torch.Tensor`:
The noisy samples.
"""
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -1162,5 +1103,5 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = alpha_t * original_samples + sigma_t * noise
return noisy_samples
def __len__(self) -> int:
def __len__(self):
return self.config.num_train_timesteps

View File

@@ -1120,9 +1120,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.

View File

@@ -662,9 +662,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.

View File

@@ -1122,9 +1122,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.

View File

@@ -1083,9 +1083,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
def index_for_timestep(
self,
timestep: Union[int, torch.Tensor],
schedule_timesteps: Optional[torch.Tensor] = None,
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
) -> int:
"""
Find the index for a given timestep in the schedule.

View File

@@ -227,21 +227,6 @@ class LayerSkipConfig(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class MagCacheConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
_backends = ["torch"]
@@ -299,10 +284,6 @@ def apply_layer_skip(*args, **kwargs):
requires_backends(apply_layer_skip, ["torch"])
def apply_mag_cache(*args, **kwargs):
requires_backends(apply_mag_cache, ["torch"])
def apply_pyramid_attention_broadcast(*args, **kwargs):
requires_backends(apply_pyramid_attention_broadcast, ["torch"])

View File

@@ -4112,21 +4112,6 @@ class ZImageImg2ImgPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class ZImageInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class ZImageOmniPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -1,244 +0,0 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from diffusers import MagCacheConfig, apply_mag_cache
from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
from diffusers.models import ModelMixin
from diffusers.utils import logging
logger = logging.get_logger(__name__)
class DummyBlock(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
# Output is double input
# This ensures Residual = 2*Input - Input = Input
return hidden_states * 2.0
class DummyTransformer(ModelMixin):
def __init__(self):
super().__init__()
self.transformer_blocks = torch.nn.ModuleList([DummyBlock(), DummyBlock()])
def forward(self, hidden_states, encoder_hidden_states=None):
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
return hidden_states
class TupleOutputBlock(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
# Returns a tuple
return hidden_states * 2.0, encoder_hidden_states
class TupleTransformer(ModelMixin):
def __init__(self):
super().__init__()
self.transformer_blocks = torch.nn.ModuleList([TupleOutputBlock()])
def forward(self, hidden_states, encoder_hidden_states=None):
for block in self.transformer_blocks:
# Emulate Flux-like behavior
output = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = output[0]
encoder_hidden_states = output[1]
return hidden_states, encoder_hidden_states
class MagCacheTests(unittest.TestCase):
def setUp(self):
# Register standard dummy block
TransformerBlockRegistry.register(
DummyBlock,
TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None),
)
# Register tuple block (Flux style)
TransformerBlockRegistry.register(
TupleOutputBlock,
TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1),
)
def _set_context(self, model, context_name):
"""Helper to set context on all hooks in the model."""
for module in model.modules():
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook._set_context(context_name)
def _get_calibration_data(self, model):
for module in model.modules():
if hasattr(module, "_diffusers_hook"):
hook = module._diffusers_hook.get_hook("mag_cache_block_hook")
if hook:
return hook.state_manager.get_state().calibration_ratios
return []
def test_mag_cache_validation(self):
"""Test that missing mag_ratios raises ValueError."""
with self.assertRaises(ValueError):
MagCacheConfig(num_inference_steps=10, calibrate=False)
def test_mag_cache_skipping_logic(self):
"""
Tests that MagCache correctly calculates residuals and skips blocks when conditions are met.
"""
model = DummyTransformer()
# Dummy ratios: [1.0, 1.0] implies 0 accumulated error if we skip
ratios = np.array([1.0, 1.0])
config = MagCacheConfig(
threshold=100.0,
num_inference_steps=2,
retention_ratio=0.0, # Enable immediate skipping
max_skip_steps=5,
mag_ratios=ratios,
)
apply_mag_cache(model, config)
self._set_context(model, "test_context")
# Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each)
# HeadInput=10. Output=40. Residual=30.
input_t0 = torch.tensor([[[10.0]]])
output_t0 = model(input_t0)
self.assertTrue(torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed")
# Step 1: Input 11.0.
# If Skipped: Output = Input(11) + Residual(30) = 41.0
# If Computed: Output = 11 * 4 = 44.0
input_t1 = torch.tensor([[[11.0]]])
output_t1 = model(input_t1)
self.assertTrue(
torch.allclose(output_t1, torch.tensor([[[41.0]]])), f"Expected Skip (41.0), got {output_t1.item()}"
)
def test_mag_cache_retention(self):
"""Test that retention_ratio prevents skipping even if error is low."""
model = DummyTransformer()
# Ratios that imply 0 error, so it *would* skip if retention allowed it
ratios = np.array([1.0, 1.0])
config = MagCacheConfig(
threshold=100.0,
num_inference_steps=2,
retention_ratio=1.0, # Force retention for ALL steps
mag_ratios=ratios,
)
apply_mag_cache(model, config)
self._set_context(model, "test_context")
# Step 0
model(torch.tensor([[[10.0]]]))
# Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention
input_t1 = torch.tensor([[[11.0]]])
output_t1 = model(input_t1)
self.assertTrue(
torch.allclose(output_t1, torch.tensor([[[44.0]]])),
f"Expected Compute (44.0) due to retention, got {output_t1.item()}",
)
def test_mag_cache_tuple_outputs(self):
"""Test compatibility with models returning (hidden, encoder_hidden) like Flux."""
model = TupleTransformer()
ratios = np.array([1.0, 1.0])
config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=ratios)
apply_mag_cache(model, config)
self._set_context(model, "test_context")
# Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x)
# Residual = 10.0
input_t0 = torch.tensor([[[10.0]]])
enc_t0 = torch.tensor([[[1.0]]])
out_0, _ = model(input_t0, encoder_hidden_states=enc_t0)
self.assertTrue(torch.allclose(out_0, torch.tensor([[[20.0]]])))
# Step 1: Skip. Input 11.0.
# Skipped Output = 11 + 10 = 21.0
input_t1 = torch.tensor([[[11.0]]])
out_1, _ = model(input_t1, encoder_hidden_states=enc_t0)
self.assertTrue(
torch.allclose(out_1, torch.tensor([[[21.0]]])), f"Tuple skip failed. Expected 21.0, got {out_1.item()}"
)
def test_mag_cache_reset(self):
"""Test that state resets correctly after num_inference_steps."""
model = DummyTransformer()
config = MagCacheConfig(
threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0])
)
apply_mag_cache(model, config)
self._set_context(model, "test_context")
input_t = torch.ones(1, 1, 1)
model(input_t) # Step 0
model(input_t) # Step 1 (Skipped)
# Step 2 (Reset -> Step 0) -> Should Compute
# Input 2.0 -> Output 8.0
input_t2 = torch.tensor([[[2.0]]])
output_t2 = model(input_t2)
self.assertTrue(torch.allclose(output_t2, torch.tensor([[[8.0]]])), "State did not reset correctly")
def test_mag_cache_calibration(self):
"""Test that calibration mode records ratios."""
model = DummyTransformer()
config = MagCacheConfig(num_inference_steps=2, calibrate=True)
apply_mag_cache(model, config)
self._set_context(model, "test_context")
# Step 0
# HeadInput = 10. Output = 40. Residual = 30.
# Ratio 0 is placeholder 1.0
model(torch.tensor([[[10.0]]]))
# Check intermediate state
ratios = self._get_calibration_data(model)
self.assertEqual(len(ratios), 1)
self.assertEqual(ratios[0], 1.0)
# Step 1
# HeadInput = 10. Output = 40. Residual = 30.
# PrevResidual = 30. CurrResidual = 30.
# Ratio = 30/30 = 1.0
model(torch.tensor([[[10.0]]]))
# Verify it computes fully (no skip)
# If it skipped, output would be 41.0. It should be 40.0
# Actually in test setup, input is same (10.0) so output 40.0.
# Let's ensure list is empty after reset (end of step 1)
ratios_after = self._get_calibration_data(model)
self.assertEqual(ratios_after, [])

View File

@@ -27,7 +27,6 @@ from ..test_pipelines_common import (
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
FluxIPAdapterTesterMixin,
MagCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
TaylorSeerCacheTesterMixin,
@@ -42,7 +41,6 @@ class FluxPipelineFastTests(
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
TaylorSeerCacheTesterMixin,
MagCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = FluxPipeline

View File

@@ -35,7 +35,6 @@ from diffusers import (
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.mag_cache import MagCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig
from diffusers.image_processor import VaeImageProcessor
@@ -2977,59 +2976,6 @@ class TaylorSeerCacheTesterMixin:
)
class MagCacheTesterMixin:
mag_cache_config = MagCacheConfig(
threshold=0.06,
max_skip_steps=3,
retention_ratio=0.2,
num_inference_steps=50,
mag_ratios=torch.ones(50),
)
def test_mag_cache_inference(self, expected_atol: float = 0.1):
device = "cpu"
def create_pipe():
torch.manual_seed(0)
num_layers = 2
components = self.get_dummy_components(num_layers=num_layers)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
return pipe
def run_forward(pipe):
torch.manual_seed(0)
inputs = self.get_dummy_inputs(device)
# Match the config steps
inputs["num_inference_steps"] = 50
return pipe(**inputs)[0]
# 1. Run inference without MagCache (Baseline)
pipe = create_pipe()
output = run_forward(pipe).flatten()
original_image_slice = np.concatenate((output[:8], output[-8:]))
# 2. Run inference with MagCache ENABLED
pipe = create_pipe()
pipe.transformer.enable_cache(self.mag_cache_config)
output = run_forward(pipe).flatten()
image_slice_enabled = np.concatenate((output[:8], output[-8:]))
# 3. Run inference with MagCache DISABLED
pipe.transformer.disable_cache()
output = run_forward(pipe).flatten()
image_slice_disabled = np.concatenate((output[:8], output[-8:]))
assert np.allclose(original_image_slice, image_slice_enabled, atol=expected_atol), (
"MagCache outputs should not differ too much from baseline."
)
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-4), (
"Outputs after disabling cache should match original inference exactly."
)
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.

View File

@@ -1,396 +0,0 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import unittest
import numpy as np
import torch
from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
ZImageInpaintPipeline,
ZImageTransformer2DModel,
)
from diffusers.utils.testing_utils import floats_tensor
from ...testing_utils import torch_device
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
)
from ..test_pipelines_common import PipelineTesterMixin, to_np
# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
# Cannot use enable_full_determinism() which sets it to True
# Note: Z-Image does not support FP16 inference due to complex64 RoPE embeddings
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if hasattr(torch.backends, "cuda"):
torch.backends.cuda.matmul.allow_tf32 = False
class ZImageInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = ZImageInpaintPipeline
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
image_params = frozenset(["image", "mask_image"])
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"strength",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
supports_dduf = False
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def setUp(self):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
def tearDown(self):
super().tearDown()
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
def get_dummy_components(self):
torch.manual_seed(0)
transformer = ZImageTransformer2DModel(
all_patch_size=(2,),
all_f_patch_size=(1,),
in_channels=16,
dim=32,
n_layers=2,
n_refiner_layers=1,
n_heads=2,
n_kv_heads=2,
norm_eps=1e-5,
qk_norm=True,
cap_feat_dim=16,
rope_theta=256.0,
t_scale=1000.0,
axes_dims=[8, 4, 4],
axes_lens=[256, 32, 32],
)
# `x_pad_token` and `cap_pad_token` are initialized with `torch.empty` which contains
# uninitialized memory. Set them to known values for deterministic test behavior.
with torch.no_grad():
transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))
torch.manual_seed(0)
vae = AutoencoderKL(
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
block_out_channels=[32, 64],
layers_per_block=1,
latent_channels=16,
norm_num_groups=32,
sample_size=32,
scaling_factor=0.3611,
shift_factor=0.1159,
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler()
torch.manual_seed(0)
config = Qwen3Config(
hidden_size=16,
intermediate_size=16,
num_hidden_layers=2,
num_attention_heads=2,
num_key_value_heads=2,
vocab_size=151936,
max_position_embeddings=512,
)
text_encoder = Qwen3Model(config)
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0):
import random
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
# Create mask: 1 = inpaint region, 0 = preserve region
mask_image = torch.zeros((1, 1, 32, 32), device=device)
mask_image[:, :, 8:24, 8:24] = 1.0 # Inpaint center region
inputs = {
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"image": image,
"mask_image": mask_image,
"strength": 1.0,
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 3.0,
"cfg_normalization": False,
"cfg_truncation": 1.0,
"height": 32,
"width": 32,
"max_sequence_length": 16,
"output_type": "np",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
generated_image = image[0]
self.assertEqual(generated_image.shape, (32, 32, 3))
def test_inference_batch_single_identical(self):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
def test_num_images_per_prompt(self):
import inspect
sig = inspect.signature(self.pipeline_class.__call__)
if "num_images_per_prompt" not in sig.parameters:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
batch_sizes = [1, 2]
num_images_per_prompts = [1, 2]
for batch_size in batch_sizes:
for num_images_per_prompt in num_images_per_prompts:
inputs = self.get_dummy_inputs(torch_device)
for key in inputs.keys():
if key in self.batch_params:
inputs[key] = batch_size * [inputs[key]]
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
assert images.shape[0] == batch_size * num_images_per_prompt
del pipe
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_vae_tiling(self, expected_diff_max: float = 0.7):
import random
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
# Generate a larger image for the input
inputs["image"] = floats_tensor((1, 3, 128, 128), rng=random.Random(0)).to("cpu")
# Generate a larger mask for the input
mask = torch.zeros((1, 1, 128, 128), device="cpu")
mask[:, :, 32:96, 32:96] = 1.0
inputs["mask_image"] = mask
output_without_tiling = pipe(**inputs)[0]
# With tiling (standard AutoencoderKL doesn't accept parameters)
pipe.vae.enable_tiling()
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
inputs["image"] = floats_tensor((1, 3, 128, 128), rng=random.Random(0)).to("cpu")
inputs["mask_image"] = mask
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)
def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-3):
# Z-Image RoPE embeddings (complex64) have slightly higher numerical tolerance
# Inpainting mask blending adds additional numerical variance
super().test_pipeline_with_accelerator_device_map(expected_max_difference=expected_max_difference)
def test_group_offloading_inference(self):
# Block-level offloading conflicts with RoPE cache. Pipeline-level offloading (tested separately) works fine.
self.skipTest("Using test_pipeline_level_group_offloading_inference instead")
def test_save_load_float16(self, expected_max_diff=1e-2):
# Z-Image does not support FP16 due to complex64 RoPE embeddings
self.skipTest("Z-Image does not support FP16 inference")
def test_float16_inference(self, expected_max_diff=5e-2):
# Z-Image does not support FP16 due to complex64 RoPE embeddings
self.skipTest("Z-Image does not support FP16 inference")
def test_strength_parameter(self):
"""Test that strength parameter affects the output correctly."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
# Test with different strength values
inputs_low_strength = self.get_dummy_inputs(device)
inputs_low_strength["strength"] = 0.2
inputs_high_strength = self.get_dummy_inputs(device)
inputs_high_strength["strength"] = 0.8
# Both should complete without errors
output_low = pipe(**inputs_low_strength).images[0]
output_high = pipe(**inputs_high_strength).images[0]
# Outputs should be different (different amount of transformation)
self.assertFalse(np.allclose(output_low, output_high, atol=1e-3))
def test_invalid_strength(self):
"""Test that invalid strength values raise appropriate errors."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
inputs = self.get_dummy_inputs(device)
# Test strength < 0
inputs["strength"] = -0.1
with self.assertRaises(ValueError):
pipe(**inputs)
# Test strength > 1
inputs["strength"] = 1.5
with self.assertRaises(ValueError):
pipe(**inputs)
def test_mask_inpainting(self):
"""Test that the mask properly controls which regions are inpainted."""
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
# Generate with full mask (inpaint everything)
inputs_full = self.get_dummy_inputs(device)
inputs_full["mask_image"] = torch.ones((1, 1, 32, 32), device=device)
# Generate with no mask (preserve everything)
inputs_none = self.get_dummy_inputs(device)
inputs_none["mask_image"] = torch.zeros((1, 1, 32, 32), device=device)
# Both should complete without errors
output_full = pipe(**inputs_full).images[0]
output_none = pipe(**inputs_none).images[0]
# Outputs should be different (full inpaint vs preserve)
self.assertFalse(np.allclose(output_full, output_none, atol=1e-3))