mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-07 03:15:16 +08:00
Compare commits
1 Commits
component-
...
sayakpaul-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8da128067c |
@@ -53,41 +53,6 @@ image = pipe(
|
||||
image.save("zimage_img2img.png")
|
||||
```
|
||||
|
||||
## Inpainting
|
||||
|
||||
Use [`ZImageInpaintPipeline`] to inpaint specific regions of an image based on a text prompt and mask.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from diffusers import ZImageInpaintPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = ZImageInpaintPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
init_image = load_image(url).resize((1024, 1024))
|
||||
|
||||
# Create a mask (white = inpaint, black = preserve)
|
||||
mask = np.zeros((1024, 1024), dtype=np.uint8)
|
||||
mask[256:768, 256:768] = 255 # Inpaint center region
|
||||
mask_image = Image.fromarray(mask)
|
||||
|
||||
prompt = "A beautiful lake with mountains in the background"
|
||||
image = pipe(
|
||||
prompt,
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
strength=1.0,
|
||||
num_inference_steps=9,
|
||||
guidance_scale=0.0,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).images[0]
|
||||
image.save("zimage_inpaint.png")
|
||||
```
|
||||
|
||||
## ZImagePipeline
|
||||
|
||||
[[autodoc]] ZImagePipeline
|
||||
@@ -99,9 +64,3 @@ image.save("zimage_inpaint.png")
|
||||
[[autodoc]] ZImageImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ZImageInpaintPipeline
|
||||
|
||||
[[autodoc]] ZImageInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -12,85 +12,179 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# ComponentsManager
|
||||
|
||||
The [`ComponentsManager`] is a model registry and management system for Modular Diffusers. It adds and tracks models, stores useful metadata (model size, device placement, adapters), and supports offloading.
|
||||
The [`ComponentsManager`] is a model registry and management system for Modular Diffusers. It adds and tracks models, stores useful metadata (model size, device placement, adapters), prevents duplicate model instances, and supports offloading.
|
||||
|
||||
This guide will show you how to use [`ComponentsManager`] to manage components and device memory.
|
||||
|
||||
## Connect to a pipeline
|
||||
## Add a component
|
||||
|
||||
Create a [`ComponentsManager`] and pass it to a [`ModularPipeline`] with either [`~ModularPipeline.from_pretrained`] or [`~ModularPipelineBlocks.init_pipeline`].
|
||||
The [`ComponentsManager`] should be created alongside a [`ModularPipeline`] in either [`~ModularPipeline.from_pretrained`] or [`~ModularPipelineBlocks.init_pipeline`].
|
||||
|
||||
> [!TIP]
|
||||
> The `collection` parameter is optional but makes it easier to organize and manage components.
|
||||
|
||||
<hfoptions id="create">
|
||||
<hfoption id="from_pretrained">
|
||||
|
||||
```py
|
||||
from diffusers import ModularPipeline, ComponentsManager
|
||||
import torch
|
||||
|
||||
manager = ComponentsManager()
|
||||
pipe = ModularPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", components_manager=manager)
|
||||
pipe.load_components(torch_dtype=torch.bfloat16)
|
||||
comp = ComponentsManager()
|
||||
pipe = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test1")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="init_pipeline">
|
||||
|
||||
```py
|
||||
from diffusers import ModularPipelineBlocks, ComponentsManager
|
||||
import torch
|
||||
manager = ComponentsManager()
|
||||
blocks = ModularPipelineBlocks.from_pretrained("diffusers/Florence2-image-Annotator", trust_remote_code=True)
|
||||
pipe= blocks.init_pipeline(components_manager=manager)
|
||||
pipe.load_components(torch_dtype=torch.bfloat16)
|
||||
from diffusers import ComponentsManager
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
|
||||
|
||||
t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
|
||||
|
||||
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
|
||||
components = ComponentsManager()
|
||||
t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Components loaded by the pipeline are automatically registered in the manager. You can inspect them right away.
|
||||
|
||||
## Inspect components
|
||||
|
||||
Print the [`ComponentsManager`] to see all registered components, including their class, device placement, dtype, memory size, and load ID.
|
||||
|
||||
The output below corresponds to the `from_pretrained` example above
|
||||
Components are only loaded and registered when using [`~ModularPipeline.load_components`] or [`~ModularPipeline.load_components`]. The example below uses [`~ModularPipeline.load_components`] to create a second pipeline that reuses all the components from the first one, and assigns it to a different collection
|
||||
|
||||
```py
|
||||
Components:
|
||||
=============================================================================================================================
|
||||
Models:
|
||||
-----------------------------------------------------------------------------------------------------------------------------
|
||||
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID
|
||||
-----------------------------------------------------------------------------------------------------------------------------
|
||||
text_encoder_140458257514752 | Qwen3Model | cpu | torch.bfloat16 | 7.49 | Tongyi-MAI/Z-Image-Turbo|text_encoder|null|null
|
||||
vae_140458257515376 | AutoencoderKL | cpu | torch.bfloat16 | 0.16 | Tongyi-MAI/Z-Image-Turbo|vae|null|null
|
||||
transformer_140458257515616 | ZImageTransformer2DModel | cpu | torch.bfloat16 | 11.46 | Tongyi-MAI/Z-Image-Turbo|transformer|null|null
|
||||
-----------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Other Components:
|
||||
-----------------------------------------------------------------------------------------------------------------------------
|
||||
ID | Class | Collection
|
||||
-----------------------------------------------------------------------------------------------------------------------------
|
||||
scheduler_140461023555264 | FlowMatchEulerDiscreteScheduler | N/A
|
||||
tokenizer_140458256346432 | Qwen2Tokenizer | N/A
|
||||
-----------------------------------------------------------------------------------------------------------------------------
|
||||
pipe.load_components()
|
||||
pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2")
|
||||
```
|
||||
|
||||
The table shows models (with device, dtype, and memory info) separately from other components like schedulers and tokenizers. If any models have LoRA adapters, IP-Adapters, or quantization applied, that information is displayed in an additional section at the bottom.
|
||||
Use the [`~ModularPipeline.null_component_names`] property to identify any components that need to be loaded, retrieve them with [`~ComponentsManager.get_components_by_names`], and then call [`~ModularPipeline.update_components`] to add the missing components.
|
||||
|
||||
```py
|
||||
pipe2.null_component_names
|
||||
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'image_encoder', 'unet', 'vae', 'scheduler', 'controlnet']
|
||||
|
||||
comp_dict = comp.get_components_by_names(names=pipe2.null_component_names)
|
||||
pipe2.update_components(**comp_dict)
|
||||
```
|
||||
|
||||
To add individual components, use the [`~ComponentsManager.add`] method. This registers a component with a unique id.
|
||||
|
||||
```py
|
||||
from diffusers import AutoModel
|
||||
|
||||
text_encoder = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
|
||||
component_id = comp.add("text_encoder", text_encoder)
|
||||
comp
|
||||
```
|
||||
|
||||
Use [`~ComponentsManager.remove`] to remove a component using their id.
|
||||
|
||||
```py
|
||||
comp.remove("text_encoder_139917733042864")
|
||||
```
|
||||
|
||||
## Retrieve a component
|
||||
|
||||
The [`ComponentsManager`] provides several methods to retrieve registered components.
|
||||
|
||||
### get_one
|
||||
|
||||
The [`~ComponentsManager.get_one`] method returns a single component and supports pattern matching for the `name` parameter. If multiple components match, [`~ComponentsManager.get_one`] returns an error.
|
||||
|
||||
| Pattern | Example | Description |
|
||||
|-------------|----------------------------------|-------------------------------------------|
|
||||
| exact | `comp.get_one(name="unet")` | exact name match |
|
||||
| wildcard | `comp.get_one(name="unet*")` | names starting with "unet" |
|
||||
| exclusion | `comp.get_one(name="!unet")` | exclude components named "unet" |
|
||||
| or | `comp.get_one(name="unet|vae")` | name is "unet" or "vae" |
|
||||
|
||||
[`~ComponentsManager.get_one`] also filters components by the `collection` argument or `load_id` argument.
|
||||
|
||||
```py
|
||||
comp.get_one(name="unet", collection="sdxl")
|
||||
```
|
||||
|
||||
### get_components_by_names
|
||||
|
||||
The [`~ComponentsManager.get_components_by_names`] method accepts a list of names and returns a dictionary mapping names to components. This is especially useful with [`ModularPipeline`] since they provide lists of required component names and the returned dictionary can be passed directly to [`~ModularPipeline.update_components`].
|
||||
|
||||
```py
|
||||
component_dict = comp.get_components_by_names(names=["text_encoder", "unet", "vae"])
|
||||
{"text_encoder": component1, "unet": component2, "vae": component3}
|
||||
```
|
||||
|
||||
## Duplicate detection
|
||||
|
||||
It is recommended to load model components with [`ComponentSpec`] to assign components with a unique id that encodes their loading parameters. This allows [`ComponentsManager`] to automatically detect and prevent duplicate model instances even when different objects represent the same underlying checkpoint.
|
||||
|
||||
```py
|
||||
from diffusers import ComponentSpec, ComponentsManager
|
||||
from transformers import CLIPTextModel
|
||||
|
||||
comp = ComponentsManager()
|
||||
|
||||
# Create ComponentSpec for the first text encoder
|
||||
spec = ComponentSpec(name="text_encoder", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=AutoModel)
|
||||
# Create ComponentSpec for a duplicate text encoder (it is same checkpoint, from the same repo/subfolder)
|
||||
spec_duplicated = ComponentSpec(name="text_encoder_duplicated", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=CLIPTextModel)
|
||||
|
||||
# Load and add both components - the manager will detect they're the same model
|
||||
comp.add("text_encoder", spec.load())
|
||||
comp.add("text_encoder_duplicated", spec_duplicated.load())
|
||||
```
|
||||
|
||||
This returns a warning with instructions for removing the duplicate.
|
||||
|
||||
```py
|
||||
ComponentsManager: adding component 'text_encoder_duplicated_139917580682672', but it has duplicate load_id 'stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null' with existing components: text_encoder_139918506246832. To remove a duplicate, call `components_manager.remove('<component_id>')`.
|
||||
'text_encoder_duplicated_139917580682672'
|
||||
```
|
||||
|
||||
You could also add a component without using [`ComponentSpec`] and duplicate detection still works in most cases even if you're adding the same component under a different name.
|
||||
|
||||
However, [`ComponentManager`] can't detect duplicates when you load the same component into different objects. In this case, you should load a model with [`ComponentSpec`].
|
||||
|
||||
```py
|
||||
text_encoder_2 = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
|
||||
comp.add("text_encoder", text_encoder_2)
|
||||
'text_encoder_139917732983664'
|
||||
```
|
||||
|
||||
## Collections
|
||||
|
||||
Collections are labels assigned to components for better organization and management. Add a component to a collection with the `collection` argument in [`~ComponentsManager.add`].
|
||||
|
||||
Only one component per name is allowed in each collection. Adding a second component with the same name automatically removes the first component.
|
||||
|
||||
```py
|
||||
from diffusers import ComponentSpec, ComponentsManager
|
||||
|
||||
comp = ComponentsManager()
|
||||
# Create ComponentSpec for the first UNet
|
||||
spec = ComponentSpec(name="unet", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", type_hint=AutoModel)
|
||||
# Create ComponentSpec for a different UNet
|
||||
spec2 = ComponentSpec(name="unet", repo="RunDiffusion/Juggernaut-XL-v9", subfolder="unet", type_hint=AutoModel, variant="fp16")
|
||||
|
||||
# Add both UNets to the same collection - the second one will replace the first
|
||||
comp.add("unet", spec.load(), collection="sdxl")
|
||||
comp.add("unet", spec2.load(), collection="sdxl")
|
||||
```
|
||||
|
||||
This makes it convenient to work with node-based systems because you can:
|
||||
|
||||
- Mark all models as loaded from one node with the `collection` label.
|
||||
- Automatically replace models when new checkpoints are loaded under the same name.
|
||||
- Batch delete all models in a collection when a node is removed.
|
||||
|
||||
## Offloading
|
||||
|
||||
The [`~ComponentsManager.enable_auto_cpu_offload`] method is a global offloading strategy that works across all models regardless of which pipeline is using them. Once enabled, you don't need to worry about device placement if you add or remove components.
|
||||
|
||||
```py
|
||||
manager.enable_auto_cpu_offload(device="cuda")
|
||||
comp.enable_auto_cpu_offload(device="cuda")
|
||||
```
|
||||
|
||||
All models begin on the CPU and [`ComponentsManager`] moves them to the appropriate device right before they're needed, and moves other models back to the CPU when GPU memory is low.
|
||||
|
||||
To disable offloading, call [~ComponentsManager.disable_auto_cpu_offload].
|
||||
|
||||
```py
|
||||
manager.disable_auto_cpu_offload()
|
||||
```
|
||||
You can set your own rules for which models to offload first.
|
||||
|
||||
@@ -111,57 +111,3 @@ config = TaylorSeerCacheConfig(
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## MagCache
|
||||
|
||||
[MagCache](https://github.com/Zehong-Ma/MagCache) accelerates inference by skipping transformer blocks based on the magnitude of the residual update. It observes that the magnitude of updates (Output - Input) decays predictably over the diffusion process. By accumulating an "error budget" based on pre-computed magnitude ratios, it dynamically decides when to skip computation and reuse the previous residual.
|
||||
|
||||
MagCache relies on **Magnitude Ratios** (`mag_ratios`), which describe this decay curve. These ratios are specific to the model checkpoint and scheduler.
|
||||
|
||||
### Usage
|
||||
|
||||
To use MagCache, you typically follow a two-step process: **Calibration** and **Inference**.
|
||||
|
||||
1. **Calibration**: Run inference once with `calibrate=True`. The hook will measure the residual magnitudes and print the calculated ratios to the console.
|
||||
2. **Inference**: Pass these ratios to `MagCacheConfig` to enable acceleration.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import FluxPipeline, MagCacheConfig
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-schnell",
|
||||
torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
|
||||
# 1. Calibration Step
|
||||
# Run full inference to measure model behavior.
|
||||
calib_config = MagCacheConfig(calibrate=True, num_inference_steps=4)
|
||||
pipe.transformer.enable_cache(calib_config)
|
||||
|
||||
# Run a prompt to trigger calibration
|
||||
pipe("A cat playing chess", num_inference_steps=4)
|
||||
# Logs will print something like: "MagCache Calibration Results: [1.0, 1.37, 0.97, 0.87]"
|
||||
|
||||
# 2. Inference Step
|
||||
# Apply the specific ratios obtained from calibration for optimized speed.
|
||||
# Note: For Flux models, you can also import defaults:
|
||||
# from diffusers.hooks.mag_cache import FLUX_MAG_RATIOS
|
||||
mag_config = MagCacheConfig(
|
||||
mag_ratios=[1.0, 1.37, 0.97, 0.87],
|
||||
num_inference_steps=4
|
||||
)
|
||||
|
||||
pipe.transformer.enable_cache(mag_config)
|
||||
|
||||
image = pipe("A cat playing chess", num_inference_steps=4).images[0]
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> `mag_ratios` represent the model's intrinsic magnitude decay curve. Ratios calibrated for a high number of steps (e.g., 50) can be reused for lower step counts (e.g., 20). The implementation uses interpolation to map the curve to the current number of inference steps.
|
||||
|
||||
> [!TIP]
|
||||
> For pipelines that run Classifier-Free Guidance sequentially (like Kandinsky 5.0), the calibration log might print two arrays: one for the Conditional pass and one for the Unconditional pass. In most cases, you should use the first array (Conditional).
|
||||
|
||||
> [!TIP]
|
||||
> For pipelines that run Classifier-Free Guidance in a **batched** manner (like SDXL or Flux), the `hidden_states` processed by the model contain both conditional and unconditional branches concatenated together. The calibration process automatically accounts for this, producing a single array of ratios that represents the joint behavior. You can use this resulting array directly without modification.
|
||||
|
||||
@@ -168,14 +168,12 @@ else:
|
||||
"FirstBlockCacheConfig",
|
||||
"HookRegistry",
|
||||
"LayerSkipConfig",
|
||||
"MagCacheConfig",
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"SmoothedEnergyGuidanceConfig",
|
||||
"TaylorSeerCacheConfig",
|
||||
"apply_faster_cache",
|
||||
"apply_first_block_cache",
|
||||
"apply_layer_skip",
|
||||
"apply_mag_cache",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
"apply_taylorseer_cache",
|
||||
]
|
||||
@@ -696,7 +694,6 @@ else:
|
||||
"ZImageControlNetInpaintPipeline",
|
||||
"ZImageControlNetPipeline",
|
||||
"ZImageImg2ImgPipeline",
|
||||
"ZImageInpaintPipeline",
|
||||
"ZImageOmniPipeline",
|
||||
"ZImagePipeline",
|
||||
]
|
||||
@@ -935,14 +932,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FirstBlockCacheConfig,
|
||||
HookRegistry,
|
||||
LayerSkipConfig,
|
||||
MagCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
SmoothedEnergyGuidanceConfig,
|
||||
TaylorSeerCacheConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_layer_skip,
|
||||
apply_mag_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
apply_taylorseer_cache,
|
||||
)
|
||||
@@ -1429,7 +1424,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ZImageControlNetInpaintPipeline,
|
||||
ZImageControlNetPipeline,
|
||||
ZImageImg2ImgPipeline,
|
||||
ZImageInpaintPipeline,
|
||||
ZImageOmniPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
|
||||
@@ -23,7 +23,6 @@ if is_torch_available():
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layer_skip import LayerSkipConfig, apply_layer_skip
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
from .mag_cache import MagCacheConfig, apply_mag_cache
|
||||
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
|
||||
from .taylorseer_cache import TaylorSeerCacheConfig, apply_taylorseer_cache
|
||||
|
||||
@@ -23,13 +23,7 @@ from ..models.attention_processor import Attention, MochiAttention
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
|
||||
|
||||
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = (
|
||||
"blocks",
|
||||
"transformer_blocks",
|
||||
"single_transformer_blocks",
|
||||
"layers",
|
||||
"visual_transformer_blocks",
|
||||
)
|
||||
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
|
||||
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ class AttentionProcessorMetadata:
|
||||
class TransformerBlockMetadata:
|
||||
return_hidden_states_index: int = None
|
||||
return_encoder_hidden_states_index: int = None
|
||||
hidden_states_argument_name: str = "hidden_states"
|
||||
|
||||
_cls: Type = None
|
||||
_cached_parameter_indices: Dict[str, int] = None
|
||||
@@ -170,7 +169,7 @@ def _register_attention_processors_metadata():
|
||||
|
||||
|
||||
def _register_transformer_blocks_metadata():
|
||||
from ..models.attention import BasicTransformerBlock, JointTransformerBlock
|
||||
from ..models.attention import BasicTransformerBlock
|
||||
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
|
||||
from ..models.transformers.transformer_bria import BriaTransformerBlock
|
||||
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
|
||||
@@ -185,7 +184,6 @@ def _register_transformer_blocks_metadata():
|
||||
HunyuanImageSingleTransformerBlock,
|
||||
HunyuanImageTransformerBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_kandinsky import Kandinsky5TransformerDecoderBlock
|
||||
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
|
||||
from ..models.transformers.transformer_mochi import MochiTransformerBlock
|
||||
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
|
||||
@@ -333,24 +331,6 @@ def _register_transformer_blocks_metadata():
|
||||
),
|
||||
)
|
||||
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=JointTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=1,
|
||||
return_encoder_hidden_states_index=0,
|
||||
),
|
||||
)
|
||||
|
||||
# Kandinsky 5.0 (Kandinsky5TransformerDecoderBlock)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=Kandinsky5TransformerDecoderBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
hidden_states_argument_name="visual_embed",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
|
||||
|
||||
@@ -1,468 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
from ._helpers import TransformerBlockRegistry
|
||||
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_MAG_CACHE_LEADER_BLOCK_HOOK = "mag_cache_leader_block_hook"
|
||||
_MAG_CACHE_BLOCK_HOOK = "mag_cache_block_hook"
|
||||
|
||||
# Default Mag Ratios for Flux models (Dev/Schnell) are provided for convenience.
|
||||
# Users must explicitly pass these to the config if using Flux.
|
||||
# Reference: https://github.com/Zehong-Ma/MagCache
|
||||
FLUX_MAG_RATIOS = torch.tensor(
|
||||
[1.0]
|
||||
+ [
|
||||
1.21094,
|
||||
1.11719,
|
||||
1.07812,
|
||||
1.0625,
|
||||
1.03906,
|
||||
1.03125,
|
||||
1.03906,
|
||||
1.02344,
|
||||
1.03125,
|
||||
1.02344,
|
||||
0.98047,
|
||||
1.01562,
|
||||
1.00781,
|
||||
1.0,
|
||||
1.00781,
|
||||
1.0,
|
||||
1.00781,
|
||||
1.0,
|
||||
1.0,
|
||||
0.99609,
|
||||
0.99609,
|
||||
0.98047,
|
||||
0.98828,
|
||||
0.96484,
|
||||
0.95703,
|
||||
0.93359,
|
||||
0.89062,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor:
|
||||
"""
|
||||
Interpolate the source array to the target length using nearest neighbor interpolation.
|
||||
"""
|
||||
src_length = len(src_array)
|
||||
if target_length == 1:
|
||||
return src_array[-1:]
|
||||
|
||||
scale = (src_length - 1) / (target_length - 1)
|
||||
grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32)
|
||||
mapped_indices = torch.round(grid * scale).long()
|
||||
return src_array[mapped_indices]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MagCacheConfig:
|
||||
r"""
|
||||
Configuration for [MagCache](https://github.com/Zehong-Ma/MagCache).
|
||||
|
||||
Args:
|
||||
threshold (`float`, defaults to `0.06`):
|
||||
The threshold for the accumulated error. If the accumulated error is below this threshold, the block
|
||||
computation is skipped. A higher threshold allows for more aggressive skipping (faster) but may degrade
|
||||
quality.
|
||||
max_skip_steps (`int`, defaults to `3`):
|
||||
The maximum number of consecutive steps that can be skipped (K in the paper).
|
||||
retention_ratio (`float`, defaults to `0.2`):
|
||||
The fraction of initial steps during which skipping is disabled to ensure stability. For example, if
|
||||
`num_inference_steps` is 28 and `retention_ratio` is 0.2, the first 6 steps will never be skipped.
|
||||
num_inference_steps (`int`, defaults to `28`):
|
||||
The number of inference steps used in the pipeline. This is required to interpolate `mag_ratios` correctly.
|
||||
mag_ratios (`torch.Tensor`, *optional*):
|
||||
The pre-computed magnitude ratios for the model. These are checkpoint-dependent. If not provided, you must
|
||||
set `calibrate=True` to calculate them for your specific model. For Flux models, you can use
|
||||
`diffusers.hooks.mag_cache.FLUX_MAG_RATIOS`.
|
||||
calibrate (`bool`, defaults to `False`):
|
||||
If True, enables calibration mode. In this mode, no blocks are skipped. Instead, the hook calculates the
|
||||
magnitude ratios for the current run and logs them at the end. Use this to obtain `mag_ratios` for new
|
||||
models or schedulers.
|
||||
"""
|
||||
|
||||
threshold: float = 0.06
|
||||
max_skip_steps: int = 3
|
||||
retention_ratio: float = 0.2
|
||||
num_inference_steps: int = 28
|
||||
mag_ratios: Optional[Union[torch.Tensor, List[float]]] = None
|
||||
calibrate: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# User MUST provide ratios OR enable calibration.
|
||||
if self.mag_ratios is None and not self.calibrate:
|
||||
raise ValueError(
|
||||
" `mag_ratios` must be provided for MagCache inference because these ratios are model-dependent.\n"
|
||||
"To get them for your model:\n"
|
||||
"1. Initialize `MagCacheConfig(calibrate=True, ...)`\n"
|
||||
"2. Run inference on your model once.\n"
|
||||
"3. Copy the printed ratios array and pass it to `mag_ratios` in the config.\n"
|
||||
"For Flux models, you can import `FLUX_MAG_RATIOS` from `diffusers.hooks.mag_cache`."
|
||||
)
|
||||
|
||||
if not self.calibrate and self.mag_ratios is not None:
|
||||
if not torch.is_tensor(self.mag_ratios):
|
||||
self.mag_ratios = torch.tensor(self.mag_ratios)
|
||||
|
||||
if len(self.mag_ratios) != self.num_inference_steps:
|
||||
logger.debug(
|
||||
f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}"
|
||||
)
|
||||
self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps)
|
||||
|
||||
|
||||
class MagCacheState(BaseState):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
# Cache for the residual (output - input) from the *previous* timestep
|
||||
self.previous_residual: torch.Tensor = None
|
||||
|
||||
# State inputs/outputs for the current forward pass
|
||||
self.head_block_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
||||
self.should_compute: bool = True
|
||||
|
||||
# MagCache accumulators
|
||||
self.accumulated_ratio: float = 1.0
|
||||
self.accumulated_err: float = 0.0
|
||||
self.accumulated_steps: int = 0
|
||||
|
||||
# Current step counter (timestep index)
|
||||
self.step_index: int = 0
|
||||
|
||||
# Calibration storage
|
||||
self.calibration_ratios: List[float] = []
|
||||
|
||||
def reset(self):
|
||||
self.previous_residual = None
|
||||
self.should_compute = True
|
||||
self.accumulated_ratio = 1.0
|
||||
self.accumulated_err = 0.0
|
||||
self.accumulated_steps = 0
|
||||
self.step_index = 0
|
||||
self.calibration_ratios = []
|
||||
|
||||
|
||||
class MagCacheHeadHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(self, state_manager: StateManager, config: MagCacheConfig):
|
||||
self.state_manager = state_manager
|
||||
self.config = config
|
||||
self._metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
unwrapped_module = unwrap_module(module)
|
||||
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
|
||||
return module
|
||||
|
||||
@torch.compiler.disable
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if self.state_manager._current_context is None:
|
||||
self.state_manager.set_context("inference")
|
||||
|
||||
arg_name = self._metadata.hidden_states_argument_name
|
||||
hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)
|
||||
|
||||
state: MagCacheState = self.state_manager.get_state()
|
||||
state.head_block_input = hidden_states
|
||||
|
||||
should_compute = True
|
||||
|
||||
if self.config.calibrate:
|
||||
# Never skip during calibration
|
||||
should_compute = True
|
||||
else:
|
||||
# MagCache Logic
|
||||
current_step = state.step_index
|
||||
if current_step >= len(self.config.mag_ratios):
|
||||
current_scale = 1.0
|
||||
else:
|
||||
current_scale = self.config.mag_ratios[current_step]
|
||||
|
||||
retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5)
|
||||
|
||||
if current_step >= retention_step:
|
||||
state.accumulated_ratio *= current_scale
|
||||
state.accumulated_steps += 1
|
||||
state.accumulated_err += abs(1.0 - state.accumulated_ratio)
|
||||
|
||||
if (
|
||||
state.previous_residual is not None
|
||||
and state.accumulated_err <= self.config.threshold
|
||||
and state.accumulated_steps <= self.config.max_skip_steps
|
||||
):
|
||||
should_compute = False
|
||||
else:
|
||||
state.accumulated_ratio = 1.0
|
||||
state.accumulated_steps = 0
|
||||
state.accumulated_err = 0.0
|
||||
|
||||
state.should_compute = should_compute
|
||||
|
||||
if not should_compute:
|
||||
logger.debug(f"MagCache: Skipping step {state.step_index}")
|
||||
# Apply MagCache: Output = Input + Previous Residual
|
||||
|
||||
output = hidden_states
|
||||
res = state.previous_residual
|
||||
|
||||
if res.device != output.device:
|
||||
res = res.to(output.device)
|
||||
|
||||
# Attempt to apply residual handling shape mismatches (e.g., text+image vs image only)
|
||||
if res.shape == output.shape:
|
||||
output = output + res
|
||||
elif (
|
||||
output.ndim == 3
|
||||
and res.ndim == 3
|
||||
and output.shape[0] == res.shape[0]
|
||||
and output.shape[2] == res.shape[2]
|
||||
):
|
||||
# Assuming concatenation where image part is at the end (standard in Flux/SD3)
|
||||
diff = output.shape[1] - res.shape[1]
|
||||
if diff > 0:
|
||||
output = output.clone()
|
||||
output[:, diff:, :] = output[:, diff:, :] + res
|
||||
else:
|
||||
logger.warning(
|
||||
f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. "
|
||||
"Cannot apply residual safely. Returning input without residual."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"MagCache: Dimension mismatch. Input {output.shape}, Residual {res.shape}. "
|
||||
"Cannot apply residual safely. Returning input without residual."
|
||||
)
|
||||
|
||||
if self._metadata.return_encoder_hidden_states_index is not None:
|
||||
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
|
||||
"encoder_hidden_states", args, kwargs
|
||||
)
|
||||
max_idx = max(
|
||||
self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
|
||||
)
|
||||
ret_list = [None] * (max_idx + 1)
|
||||
ret_list[self._metadata.return_hidden_states_index] = output
|
||||
ret_list[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
|
||||
return tuple(ret_list)
|
||||
else:
|
||||
return output
|
||||
|
||||
else:
|
||||
# Compute original forward
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
return output
|
||||
|
||||
def reset_state(self, module):
|
||||
self.state_manager.reset()
|
||||
return module
|
||||
|
||||
|
||||
class MagCacheBlockHook(ModelHook):
|
||||
def __init__(self, state_manager: StateManager, is_tail: bool = False, config: MagCacheConfig = None):
|
||||
super().__init__()
|
||||
self.state_manager = state_manager
|
||||
self.is_tail = is_tail
|
||||
self.config = config
|
||||
self._metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
unwrapped_module = unwrap_module(module)
|
||||
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
|
||||
return module
|
||||
|
||||
@torch.compiler.disable
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if self.state_manager._current_context is None:
|
||||
self.state_manager.set_context("inference")
|
||||
state: MagCacheState = self.state_manager.get_state()
|
||||
|
||||
if not state.should_compute:
|
||||
arg_name = self._metadata.hidden_states_argument_name
|
||||
hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)
|
||||
|
||||
if self.is_tail:
|
||||
# Still need to advance step index even if we skip
|
||||
self._advance_step(state)
|
||||
|
||||
if self._metadata.return_encoder_hidden_states_index is not None:
|
||||
encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
|
||||
"encoder_hidden_states", args, kwargs
|
||||
)
|
||||
max_idx = max(
|
||||
self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
|
||||
)
|
||||
ret_list = [None] * (max_idx + 1)
|
||||
ret_list[self._metadata.return_hidden_states_index] = hidden_states
|
||||
ret_list[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
|
||||
return tuple(ret_list)
|
||||
|
||||
return hidden_states
|
||||
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
|
||||
if self.is_tail:
|
||||
# Calculate residual for next steps
|
||||
if isinstance(output, tuple):
|
||||
out_hidden = output[self._metadata.return_hidden_states_index]
|
||||
else:
|
||||
out_hidden = output
|
||||
|
||||
in_hidden = state.head_block_input
|
||||
|
||||
if in_hidden is None:
|
||||
return output
|
||||
|
||||
# Determine residual
|
||||
if out_hidden.shape == in_hidden.shape:
|
||||
residual = out_hidden - in_hidden
|
||||
elif out_hidden.ndim == 3 and in_hidden.ndim == 3 and out_hidden.shape[2] == in_hidden.shape[2]:
|
||||
diff = in_hidden.shape[1] - out_hidden.shape[1]
|
||||
if diff == 0:
|
||||
residual = out_hidden - in_hidden
|
||||
else:
|
||||
residual = out_hidden - in_hidden # Fallback to matching tail
|
||||
else:
|
||||
# Fallback for completely mismatched shapes
|
||||
residual = out_hidden
|
||||
|
||||
if self.config.calibrate:
|
||||
self._perform_calibration_step(state, residual)
|
||||
|
||||
state.previous_residual = residual
|
||||
self._advance_step(state)
|
||||
|
||||
return output
|
||||
|
||||
def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor):
|
||||
if state.previous_residual is None:
|
||||
# First step has no previous residual to compare against.
|
||||
# log 1.0 as a neutral starting point.
|
||||
ratio = 1.0
|
||||
else:
|
||||
# MagCache Calibration Formula: mean(norm(curr) / norm(prev))
|
||||
# norm(dim=-1) gives magnitude of each token vector
|
||||
curr_norm = torch.linalg.norm(current_residual.float(), dim=-1)
|
||||
prev_norm = torch.linalg.norm(state.previous_residual.float(), dim=-1)
|
||||
|
||||
# Avoid division by zero
|
||||
ratio = (curr_norm / (prev_norm + 1e-8)).mean().item()
|
||||
|
||||
state.calibration_ratios.append(ratio)
|
||||
|
||||
def _advance_step(self, state: MagCacheState):
|
||||
state.step_index += 1
|
||||
if state.step_index >= self.config.num_inference_steps:
|
||||
# End of inference loop
|
||||
if self.config.calibrate:
|
||||
print("\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):")
|
||||
print(f"{state.calibration_ratios}\n")
|
||||
logger.info(f"MagCache Calibration Results: {state.calibration_ratios}")
|
||||
|
||||
# Reset state
|
||||
state.step_index = 0
|
||||
state.accumulated_ratio = 1.0
|
||||
state.accumulated_steps = 0
|
||||
state.accumulated_err = 0.0
|
||||
state.previous_residual = None
|
||||
state.calibration_ratios = []
|
||||
|
||||
|
||||
def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
|
||||
"""
|
||||
Applies MagCache to a given module (typically a Transformer).
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to apply MagCache to.
|
||||
config (`MagCacheConfig`):
|
||||
The configuration for MagCache.
|
||||
"""
|
||||
# Initialize registry on the root module so the Pipeline can set context.
|
||||
HookRegistry.check_if_exists_or_initialize(module)
|
||||
|
||||
state_manager = StateManager(MagCacheState, (), {})
|
||||
remaining_blocks = []
|
||||
|
||||
for name, submodule in module.named_children():
|
||||
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
|
||||
continue
|
||||
for index, block in enumerate(submodule):
|
||||
remaining_blocks.append((f"{name}.{index}", block))
|
||||
|
||||
if not remaining_blocks:
|
||||
logger.warning("MagCache: No transformer blocks found to apply hooks.")
|
||||
return
|
||||
|
||||
# Handle single-block models
|
||||
if len(remaining_blocks) == 1:
|
||||
name, block = remaining_blocks[0]
|
||||
logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{name}'")
|
||||
_apply_mag_cache_block_hook(block, state_manager, config, is_tail=True)
|
||||
_apply_mag_cache_head_hook(block, state_manager, config)
|
||||
return
|
||||
|
||||
head_block_name, head_block = remaining_blocks.pop(0)
|
||||
tail_block_name, tail_block = remaining_blocks.pop(-1)
|
||||
|
||||
logger.info(f"MagCache: Applying Head Hook to {head_block_name}")
|
||||
_apply_mag_cache_head_hook(head_block, state_manager, config)
|
||||
|
||||
for name, block in remaining_blocks:
|
||||
_apply_mag_cache_block_hook(block, state_manager, config)
|
||||
|
||||
logger.info(f"MagCache: Applying Tail Hook to {tail_block_name}")
|
||||
_apply_mag_cache_block_hook(tail_block, state_manager, config, is_tail=True)
|
||||
|
||||
|
||||
def _apply_mag_cache_head_hook(block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
|
||||
# Automatically remove existing hook to allow re-application (e.g. switching modes)
|
||||
if registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK) is not None:
|
||||
registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK)
|
||||
|
||||
hook = MagCacheHeadHook(state_manager, config)
|
||||
registry.register_hook(hook, _MAG_CACHE_LEADER_BLOCK_HOOK)
|
||||
|
||||
|
||||
def _apply_mag_cache_block_hook(
|
||||
block: torch.nn.Module,
|
||||
state_manager: StateManager,
|
||||
config: MagCacheConfig,
|
||||
is_tail: bool = False,
|
||||
) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
|
||||
# Automatically remove existing hook to allow re-application
|
||||
if registry.get_hook(_MAG_CACHE_BLOCK_HOOK) is not None:
|
||||
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK)
|
||||
|
||||
hook = MagCacheBlockHook(state_manager, is_tail, config)
|
||||
registry.register_hook(hook, _MAG_CACHE_BLOCK_HOOK)
|
||||
@@ -68,12 +68,10 @@ class CacheMixin:
|
||||
from ..hooks import (
|
||||
FasterCacheConfig,
|
||||
FirstBlockCacheConfig,
|
||||
MagCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
TaylorSeerCacheConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_mag_cache,
|
||||
apply_pyramid_attention_broadcast,
|
||||
apply_taylorseer_cache,
|
||||
)
|
||||
@@ -87,8 +85,6 @@ class CacheMixin:
|
||||
apply_faster_cache(self, config)
|
||||
elif isinstance(config, FirstBlockCacheConfig):
|
||||
apply_first_block_cache(self, config)
|
||||
elif isinstance(config, MagCacheConfig):
|
||||
apply_mag_cache(self, config)
|
||||
elif isinstance(config, PyramidAttentionBroadcastConfig):
|
||||
apply_pyramid_attention_broadcast(self, config)
|
||||
elif isinstance(config, TaylorSeerCacheConfig):
|
||||
@@ -103,13 +99,11 @@ class CacheMixin:
|
||||
FasterCacheConfig,
|
||||
FirstBlockCacheConfig,
|
||||
HookRegistry,
|
||||
MagCacheConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
TaylorSeerCacheConfig,
|
||||
)
|
||||
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
||||
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
|
||||
from ..hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK
|
||||
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
|
||||
from ..hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
|
||||
|
||||
@@ -124,9 +118,6 @@ class CacheMixin:
|
||||
elif isinstance(self._cache_config, FirstBlockCacheConfig):
|
||||
registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
|
||||
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, MagCacheConfig):
|
||||
registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK, recurse=True)
|
||||
registry.remove_hook(_MAG_CACHE_BLOCK_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
|
||||
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
|
||||
elif isinstance(self._cache_config, TaylorSeerCacheConfig):
|
||||
|
||||
@@ -125,9 +125,9 @@ class BriaFiboAttnProcessor:
|
||||
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
||||
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
||||
)
|
||||
hidden_states = attn.to_out[0](hidden_states.contiguous())
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states.contiguous())
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
|
||||
@@ -130,9 +130,9 @@ class FluxAttnProcessor:
|
||||
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
||||
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
||||
)
|
||||
hidden_states = attn.to_out[0](hidden_states.contiguous())
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states.contiguous())
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
|
||||
@@ -561,11 +561,11 @@ class QwenDoubleStreamAttnProcessor2_0:
|
||||
img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
|
||||
|
||||
# Apply output projections
|
||||
img_attn_output = attn.to_out[0](img_attn_output.contiguous())
|
||||
img_attn_output = attn.to_out[0](img_attn_output)
|
||||
if len(attn.to_out) > 1:
|
||||
img_attn_output = attn.to_out[1](img_attn_output) # dropout
|
||||
|
||||
txt_attn_output = attn.to_add_out(txt_attn_output.contiguous())
|
||||
txt_attn_output = attn.to_add_out(txt_attn_output)
|
||||
|
||||
return img_attn_output, txt_attn_output
|
||||
|
||||
|
||||
@@ -410,12 +410,11 @@ else:
|
||||
"Kandinsky5I2IPipeline",
|
||||
]
|
||||
_import_structure["z_image"] = [
|
||||
"ZImageControlNetInpaintPipeline",
|
||||
"ZImageControlNetPipeline",
|
||||
"ZImageImg2ImgPipeline",
|
||||
"ZImageInpaintPipeline",
|
||||
"ZImageOmniPipeline",
|
||||
"ZImagePipeline",
|
||||
"ZImageControlNetPipeline",
|
||||
"ZImageControlNetInpaintPipeline",
|
||||
"ZImageOmniPipeline",
|
||||
]
|
||||
_import_structure["skyreels_v2"] = [
|
||||
"SkyReelsV2DiffusionForcingPipeline",
|
||||
@@ -871,7 +870,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ZImageControlNetInpaintPipeline,
|
||||
ZImageControlNetPipeline,
|
||||
ZImageImg2ImgPipeline,
|
||||
ZImageInpaintPipeline,
|
||||
ZImageOmniPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
|
||||
@@ -127,7 +127,6 @@ from .z_image import (
|
||||
ZImageControlNetInpaintPipeline,
|
||||
ZImageControlNetPipeline,
|
||||
ZImageImg2ImgPipeline,
|
||||
ZImageInpaintPipeline,
|
||||
ZImageOmniPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
@@ -236,7 +235,6 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
|
||||
("qwenimage", QwenImageInpaintPipeline),
|
||||
("qwenimage-edit", QwenImageEditInpaintPipeline),
|
||||
("z-image", ZImageInpaintPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -26,7 +26,6 @@ else:
|
||||
_import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"]
|
||||
_import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"]
|
||||
_import_structure["pipeline_z_image_inpaint"] = ["ZImageInpaintPipeline"]
|
||||
_import_structure["pipeline_z_image_omni"] = ["ZImageOmniPipeline"]
|
||||
|
||||
|
||||
@@ -43,7 +42,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_z_image_controlnet import ZImageControlNetPipeline
|
||||
from .pipeline_z_image_controlnet_inpaint import ZImageControlNetInpaintPipeline
|
||||
from .pipeline_z_image_img2img import ZImageImg2ImgPipeline
|
||||
from .pipeline_z_image_inpaint import ZImageInpaintPipeline
|
||||
from .pipeline_z_image_omni import ZImageOmniPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -1,932 +0,0 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, PreTrainedModel
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import ZImageTransformer2DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from .pipeline_output import ZImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import ZImageInpaintPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = ZImageInpaintPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
>>> init_image = load_image(url).resize((1024, 1024))
|
||||
|
||||
>>> # Create a mask (white = inpaint, black = preserve)
|
||||
>>> import numpy as np
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> mask = np.zeros((1024, 1024), dtype=np.uint8)
|
||||
>>> mask[256:768, 256:768] = 255 # Inpaint center region
|
||||
>>> mask_image = Image.fromarray(mask)
|
||||
|
||||
>>> prompt = "A beautiful lake with mountains in the background"
|
||||
>>> image = pipe(
|
||||
... prompt,
|
||||
... image=init_image,
|
||||
... mask_image=mask_image,
|
||||
... strength=1.0,
|
||||
... num_inference_steps=9,
|
||||
... guidance_scale=0.0,
|
||||
... generator=torch.Generator("cuda").manual_seed(42),
|
||||
... ).images[0]
|
||||
>>> image.save("zimage_inpaint.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class ZImageInpaintPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
|
||||
r"""
|
||||
The ZImage pipeline for inpainting.
|
||||
|
||||
Args:
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`PreTrainedModel`]):
|
||||
A text encoder model to encode text prompts.
|
||||
tokenizer ([`AutoTokenizer`]):
|
||||
A tokenizer to tokenize text prompts.
|
||||
transformer ([`ZImageTransformer2DModel`]):
|
||||
A ZImage transformer model to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "mask", "masked_image_latents"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: PreTrainedModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
transformer: ZImageTransformer2DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
transformer=transformer,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor * 2,
|
||||
do_normalize=False,
|
||||
do_binarize=True,
|
||||
do_convert_grayscale=True,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ["" for _ in prompt]
|
||||
else:
|
||||
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
assert len(prompt) == len(negative_prompt)
|
||||
negative_prompt_embeds = self._encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
device=device,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
else:
|
||||
negative_prompt_embeds = []
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
) -> List[torch.FloatTensor]:
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt_embeds is not None:
|
||||
return prompt_embeds
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
for i, prompt_item in enumerate(prompt):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt_item},
|
||||
]
|
||||
prompt_item = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
)
|
||||
prompt[i] = prompt_item
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_masks,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
|
||||
embeddings_list = []
|
||||
|
||||
for i in range(len(prompt_embeds)):
|
||||
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
|
||||
|
||||
return embeddings_list
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
||||
|
||||
t_start = int(max(num_inference_steps - init_timestep, 0))
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_mask_latents(
|
||||
self,
|
||||
mask,
|
||||
masked_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
):
|
||||
"""Prepare mask and masked image latents for inpainting.
|
||||
|
||||
Args:
|
||||
mask: Binary mask tensor where 1 = inpaint region, 0 = preserve region.
|
||||
masked_image: Original image with masked regions zeroed out.
|
||||
batch_size: Number of images to generate.
|
||||
height: Output image height.
|
||||
width: Output image width.
|
||||
dtype: Data type for the tensors.
|
||||
device: Device to place tensors on.
|
||||
generator: Random generator for reproducibility.
|
||||
|
||||
Returns:
|
||||
Tuple of (mask, masked_image_latents) prepared for the denoising loop.
|
||||
"""
|
||||
# Calculate latent dimensions
|
||||
latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
# Resize mask to latent dimensions
|
||||
mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width), mode="nearest")
|
||||
mask = mask.to(device=device, dtype=dtype)
|
||||
|
||||
# Encode masked image to latents
|
||||
masked_image = masked_image.to(device=device, dtype=dtype)
|
||||
if isinstance(generator, list):
|
||||
masked_image_latents = [
|
||||
retrieve_latents(self.vae.encode(masked_image[i : i + 1]), generator=generator[i])
|
||||
for i in range(masked_image.shape[0])
|
||||
]
|
||||
masked_image_latents = torch.cat(masked_image_latents, dim=0)
|
||||
else:
|
||||
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
|
||||
|
||||
# Apply VAE scaling
|
||||
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# Expand for batch size
|
||||
if mask.shape[0] < batch_size:
|
||||
if not batch_size % mask.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
||||
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
||||
" of masks that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
||||
if masked_image_latents.shape[0] < batch_size:
|
||||
if not batch_size % masked_image_latents.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
|
||||
|
||||
return mask, masked_image_latents
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image,
|
||||
timestep,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
"""Prepare latents for inpainting, returning noise and image_latents for blending.
|
||||
|
||||
Returns:
|
||||
Tuple of (latents, noise, image_latents) where:
|
||||
- latents: Noised image latents for denoising
|
||||
- noise: The noise tensor used for blending
|
||||
- image_latents: Clean image latents for blending
|
||||
"""
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
# Generate noise for blending even if latents are provided
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
# Encode image for blending
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
||||
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
image_latents = torch.cat([image_latents] * (batch_size // image_latents.shape[0]), dim=0)
|
||||
return latents.to(device=device, dtype=dtype), noise, image_latents
|
||||
|
||||
# Encode the input image
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if image.shape[1] != num_channels_latents:
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
||||
|
||||
# Apply scaling (inverse of decoding: decode does latents/scaling_factor + shift_factor)
|
||||
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
else:
|
||||
image_latents = image
|
||||
|
||||
# Handle batch size expansion
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
|
||||
# Generate noise for both initial noising and later blending
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
# Add noise using flow matching scale_noise
|
||||
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
|
||||
|
||||
return latents, noise, image_latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
image,
|
||||
mask_image,
|
||||
strength,
|
||||
height,
|
||||
width,
|
||||
output_type,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if image is None:
|
||||
raise ValueError("`image` input cannot be undefined for inpainting.")
|
||||
|
||||
if mask_image is None:
|
||||
raise ValueError("`mask_image` input cannot be undefined for inpainting.")
|
||||
|
||||
if output_type not in ["latent", "pil", "np", "pt"]:
|
||||
raise ValueError(f"`output_type` must be one of 'latent', 'pil', 'np', or 'pt', but got {output_type}")
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: PipelineImageInput = None,
|
||||
mask_image: PipelineImageInput = None,
|
||||
masked_image_latents: Optional[torch.FloatTensor] = None,
|
||||
strength: float = 1.0,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
cfg_normalization: bool = False,
|
||||
cfg_truncation: float = 1.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for inpainting.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
|
||||
list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
|
||||
a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
|
||||
mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing a mask image for inpainting. White pixels (value 1) in the
|
||||
mask will be inpainted, black pixels (value 0) will be preserved from the original image.
|
||||
masked_image_latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-encoded masked image latents. If provided, the masked image encoding step will be skipped.
|
||||
strength (`float`, *optional*, defaults to 1.0):
|
||||
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
||||
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
||||
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
||||
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
||||
essentially ignores `image` in the masked region.
|
||||
height (`int`, *optional*, defaults to 1024):
|
||||
The height in pixels of the generated image. If not provided, uses the input image height.
|
||||
width (`int`, *optional*, defaults to 1024):
|
||||
The width in pixels of the generated image. If not provided, uses the input image width.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
cfg_normalization (`bool`, *optional*, defaults to False):
|
||||
Whether to apply configuration normalization.
|
||||
cfg_truncation (`float`, *optional*, defaults to 1.0):
|
||||
The truncation value for configuration.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
|
||||
tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
mask_image=mask_image,
|
||||
strength=strength,
|
||||
height=height,
|
||||
width=width,
|
||||
output_type=output_type,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
# 2. Preprocess image and mask
|
||||
init_image = self.image_processor.preprocess(image)
|
||||
init_image = init_image.to(dtype=torch.float32)
|
||||
|
||||
# Get dimensions from the preprocessed image if not specified
|
||||
if height is None:
|
||||
height = init_image.shape[-2]
|
||||
if width is None:
|
||||
width = init_image.shape[-1]
|
||||
|
||||
vae_scale = self.vae_scale_factor * 2
|
||||
if height % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Height must be divisible by {vae_scale} (got {height}). "
|
||||
f"Please adjust the height to a multiple of {vae_scale}."
|
||||
)
|
||||
if width % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Width must be divisible by {vae_scale} (got {width}). "
|
||||
f"Please adjust the width to a multiple of {vae_scale}."
|
||||
)
|
||||
|
||||
# Preprocess mask
|
||||
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
self._cfg_normalization = cfg_normalization
|
||||
self._cfg_truncation = cfg_truncation
|
||||
|
||||
# 3. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = len(prompt_embeds)
|
||||
|
||||
# If prompt_embeds is provided and prompt is None, skip encoding
|
||||
if prompt_embeds is not None and prompt is None:
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"When `prompt_embeds` is provided without `prompt`, "
|
||||
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
|
||||
)
|
||||
else:
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.in_channels
|
||||
|
||||
# Repeat prompt_embeds for num_images_per_prompt
|
||||
if num_images_per_prompt > 1:
|
||||
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds:
|
||||
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
|
||||
actual_batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
# Calculate latent dimensions for image_seq_len
|
||||
latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
image_seq_len = (latent_height // 2) * (latent_width // 2)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
self.scheduler.sigma_min = 0.0
|
||||
scheduler_kwargs = {"mu": mu}
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
**scheduler_kwargs,
|
||||
)
|
||||
|
||||
# 6. Adjust timesteps based on strength
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
if num_inference_steps < 1:
|
||||
raise ValueError(
|
||||
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
|
||||
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
||||
)
|
||||
latent_timestep = timesteps[:1].repeat(actual_batch_size)
|
||||
|
||||
# 7. Prepare latents from image (returns noise and image_latents for blending)
|
||||
latents, noise, image_latents = self.prepare_latents(
|
||||
init_image,
|
||||
latent_timestep,
|
||||
actual_batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds[0].dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 8. Prepare mask and masked image latents
|
||||
# Create masked image: preserve only unmasked regions (mask=0)
|
||||
if masked_image_latents is None:
|
||||
masked_image = init_image * (mask < 0.5)
|
||||
else:
|
||||
masked_image = None # Will use provided masked_image_latents
|
||||
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask,
|
||||
masked_image if masked_image is not None else init_image,
|
||||
actual_batch_size,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds[0].dtype,
|
||||
device,
|
||||
generator,
|
||||
)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 9. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0])
|
||||
timestep = (1000 - timestep) / 1000
|
||||
# Normalized time for time-aware config (0 at start, 1 at end)
|
||||
t_norm = timestep[0].item()
|
||||
|
||||
# Handle cfg truncation
|
||||
current_guidance_scale = self.guidance_scale
|
||||
if (
|
||||
self.do_classifier_free_guidance
|
||||
and self._cfg_truncation is not None
|
||||
and float(self._cfg_truncation) <= 1
|
||||
):
|
||||
if t_norm > self._cfg_truncation:
|
||||
current_guidance_scale = 0.0
|
||||
|
||||
# Run CFG only if configured AND scale is non-zero
|
||||
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
|
||||
|
||||
if apply_cfg:
|
||||
latents_typed = latents.to(self.transformer.dtype)
|
||||
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
|
||||
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
|
||||
timestep_model_input = timestep.repeat(2)
|
||||
else:
|
||||
latent_model_input = latents.to(self.transformer.dtype)
|
||||
prompt_embeds_model_input = prompt_embeds
|
||||
timestep_model_input = timestep
|
||||
|
||||
latent_model_input = latent_model_input.unsqueeze(2)
|
||||
latent_model_input_list = list(latent_model_input.unbind(dim=0))
|
||||
|
||||
model_out_list = self.transformer(
|
||||
latent_model_input_list,
|
||||
timestep_model_input,
|
||||
prompt_embeds_model_input,
|
||||
)[0]
|
||||
|
||||
if apply_cfg:
|
||||
# Perform CFG
|
||||
pos_out = model_out_list[:actual_batch_size]
|
||||
neg_out = model_out_list[actual_batch_size:]
|
||||
|
||||
noise_pred = []
|
||||
for j in range(actual_batch_size):
|
||||
pos = pos_out[j].float()
|
||||
neg = neg_out[j].float()
|
||||
|
||||
pred = pos + current_guidance_scale * (pos - neg)
|
||||
|
||||
# Renormalization
|
||||
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
|
||||
ori_pos_norm = torch.linalg.vector_norm(pos)
|
||||
new_pos_norm = torch.linalg.vector_norm(pred)
|
||||
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
|
||||
if new_pos_norm > max_new_norm:
|
||||
pred = pred * (max_new_norm / new_pos_norm)
|
||||
|
||||
noise_pred.append(pred)
|
||||
|
||||
noise_pred = torch.stack(noise_pred, dim=0)
|
||||
else:
|
||||
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
|
||||
|
||||
noise_pred = noise_pred.squeeze(2)
|
||||
noise_pred = -noise_pred
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
|
||||
assert latents.dtype == torch.float32
|
||||
|
||||
# Inpainting blend: combine denoised latents with original image latents
|
||||
init_latents_proper = image_latents
|
||||
|
||||
# Re-scale original latents to current noise level for proper blending
|
||||
if i < len(timesteps) - 1:
|
||||
noise_timestep = timesteps[i + 1]
|
||||
init_latents_proper = self.scheduler.scale_noise(
|
||||
init_latents_proper, torch.tensor([noise_timestep]), noise
|
||||
)
|
||||
|
||||
# Blend: mask=1 for inpaint region (use denoised), mask=0 for preserve region (use original)
|
||||
latents = (1 - mask) * init_latents_proper + mask * latents
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
mask = callback_outputs.pop("mask", mask)
|
||||
masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ZImagePipelineOutput(images=image)
|
||||
@@ -79,8 +79,7 @@ MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
|
||||
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
|
||||
# there is no need to call any kernel for fp16/bf16
|
||||
if qweight_type in UNQUANTIZED_TYPES:
|
||||
weight = dequantize_gguf_tensor(qweight)
|
||||
return x @ weight.T
|
||||
return x @ qweight.T
|
||||
|
||||
# TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
|
||||
# contiguous batching and inefficient with diffusers' batching,
|
||||
|
||||
@@ -545,9 +545,7 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -867,9 +867,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -245,26 +245,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if (
|
||||
sum(
|
||||
[
|
||||
self.config.use_beta_sigmas,
|
||||
self.config.use_exponential_sigmas,
|
||||
self.config.use_karras_sigmas,
|
||||
]
|
||||
)
|
||||
> 1
|
||||
):
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
||||
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
|
||||
deprecate(
|
||||
"algorithm_types dpmsolver and sde-dpmsolver",
|
||||
"1.0.0",
|
||||
deprecation_message,
|
||||
)
|
||||
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
@@ -272,15 +259,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = (
|
||||
torch.linspace(
|
||||
beta_start**0.5,
|
||||
beta_end**0.5,
|
||||
num_train_timesteps,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
** 2
|
||||
)
|
||||
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
@@ -308,12 +287,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
# settings for DPM-Solver
|
||||
if algorithm_type not in [
|
||||
"dpmsolver",
|
||||
"dpmsolver++",
|
||||
"sde-dpmsolver",
|
||||
"sde-dpmsolver++",
|
||||
]:
|
||||
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
|
||||
if algorithm_type == "deis":
|
||||
self.register_to_config(algorithm_type="dpmsolver++")
|
||||
else:
|
||||
@@ -750,7 +724,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
*args,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
sample: torch.Tensor = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -764,7 +738,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
Returns:
|
||||
@@ -848,7 +822,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
*args,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
sample: torch.Tensor = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -858,10 +832,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
noise (`torch.Tensor`, *optional*):
|
||||
The noise tensor.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
@@ -888,10 +860,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
)
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
@@ -922,7 +891,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output_list: List[torch.Tensor],
|
||||
*args,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
sample: torch.Tensor = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -932,7 +901,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.Tensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
Returns:
|
||||
@@ -1045,7 +1014,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output_list: List[torch.Tensor],
|
||||
*args,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
sample: torch.Tensor = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -1055,10 +1024,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.Tensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by diffusion process.
|
||||
noise (`torch.Tensor`, *optional*):
|
||||
The noise tensor.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
@@ -1139,9 +1106,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return x_t
|
||||
|
||||
def index_for_timestep(
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
@@ -1251,10 +1216,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample = sample.to(torch.float32)
|
||||
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
|
||||
noise = randn_tensor(
|
||||
model_output.shape,
|
||||
generator=generator,
|
||||
device=model_output.device,
|
||||
dtype=torch.float32,
|
||||
model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
|
||||
)
|
||||
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
|
||||
noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
|
||||
|
||||
@@ -141,10 +141,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
use_flow_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
|
||||
flow_shift (`float`, *optional*, defaults to 1.0):
|
||||
The flow shift factor. Valid only when `use_flow_sigmas=True`.
|
||||
lambda_min_clipped (`float`, defaults to `-inf`):
|
||||
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
||||
cosine (`squaredcos_cap_v2`) noise schedule.
|
||||
@@ -167,15 +163,15 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
solver_order: int = 2,
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
|
||||
prediction_type: str = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++",
|
||||
solver_type: Literal["midpoint", "heun"] = "midpoint",
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
euler_at_final: bool = False,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
@@ -184,32 +180,19 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_flow_sigmas: Optional[bool] = False,
|
||||
flow_shift: Optional[float] = 1.0,
|
||||
lambda_min_clipped: float = -float("inf"),
|
||||
variance_type: Optional[Literal["learned", "learned_range"]] = None,
|
||||
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
|
||||
variance_type: Optional[str] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if (
|
||||
sum(
|
||||
[
|
||||
self.config.use_beta_sigmas,
|
||||
self.config.use_exponential_sigmas,
|
||||
self.config.use_karras_sigmas,
|
||||
]
|
||||
)
|
||||
> 1
|
||||
):
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
||||
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
|
||||
deprecate(
|
||||
"algorithm_types dpmsolver and sde-dpmsolver",
|
||||
"1.0.0",
|
||||
deprecation_message,
|
||||
)
|
||||
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
@@ -217,15 +200,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = (
|
||||
torch.linspace(
|
||||
beta_start**0.5,
|
||||
beta_end**0.5,
|
||||
num_train_timesteps,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
** 2
|
||||
)
|
||||
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
@@ -244,12 +219,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
# settings for DPM-Solver
|
||||
if algorithm_type not in [
|
||||
"dpmsolver",
|
||||
"dpmsolver++",
|
||||
"sde-dpmsolver",
|
||||
"sde-dpmsolver++",
|
||||
]:
|
||||
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
|
||||
if algorithm_type == "deis":
|
||||
self.register_to_config(algorithm_type="dpmsolver++")
|
||||
else:
|
||||
@@ -280,11 +250,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -416,7 +382,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
"""
|
||||
Convert sigma values to corresponding timestep values through interpolation.
|
||||
|
||||
@@ -453,7 +419,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
"""
|
||||
Convert sigma values to alpha_t and sigma_t values.
|
||||
|
||||
@@ -475,7 +441,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
|
||||
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
||||
"""
|
||||
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
|
||||
Models](https://huggingface.co/papers/2206.00364).
|
||||
@@ -601,7 +567,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
*args,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
sample: torch.Tensor = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -615,7 +581,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
Returns:
|
||||
@@ -700,7 +666,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
*args,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
sample: torch.Tensor = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -710,10 +676,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
noise (`torch.Tensor`, *optional*):
|
||||
The noise tensor.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
@@ -740,10 +704,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
)
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
@@ -775,7 +736,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output_list: List[torch.Tensor],
|
||||
*args,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
sample: torch.Tensor = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -785,7 +746,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.Tensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
Returns:
|
||||
@@ -899,7 +860,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output_list: List[torch.Tensor],
|
||||
*args,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
sample: torch.Tensor = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -909,10 +870,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.Tensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by diffusion process.
|
||||
noise (`torch.Tensor`, *optional*):
|
||||
The noise tensor.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
@@ -992,7 +951,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return x_t
|
||||
|
||||
def _init_step_index(self, timestep: Union[int, torch.Tensor]):
|
||||
def _init_step_index(self, timestep):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
@@ -1016,7 +975,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output: torch.Tensor,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
sample: torch.Tensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
generator=None,
|
||||
variance_noise: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
@@ -1068,10 +1027,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
|
||||
noise = randn_tensor(
|
||||
model_output.shape,
|
||||
generator=generator,
|
||||
device=model_output.device,
|
||||
dtype=model_output.dtype,
|
||||
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
|
||||
)
|
||||
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
|
||||
noise = variance_noise
|
||||
@@ -1118,21 +1074,6 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: torch.Tensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Add noise to the clean `original_samples` using the scheduler's equivalent function.
|
||||
|
||||
Args:
|
||||
original_samples (`torch.Tensor`):
|
||||
The original samples to add noise to.
|
||||
noise (`torch.Tensor`):
|
||||
The noise tensor.
|
||||
timesteps (`torch.IntTensor`):
|
||||
The timesteps at which to add noise.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The noisy samples.
|
||||
"""
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
@@ -1162,5 +1103,5 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self) -> int:
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -1120,9 +1120,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -662,9 +662,7 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -1122,9 +1122,7 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -1083,9 +1083,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -227,21 +227,6 @@ class LayerSkipConfig(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class MagCacheConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -299,10 +284,6 @@ def apply_layer_skip(*args, **kwargs):
|
||||
requires_backends(apply_layer_skip, ["torch"])
|
||||
|
||||
|
||||
def apply_mag_cache(*args, **kwargs):
|
||||
requires_backends(apply_mag_cache, ["torch"])
|
||||
|
||||
|
||||
def apply_pyramid_attention_broadcast(*args, **kwargs):
|
||||
requires_backends(apply_pyramid_attention_broadcast, ["torch"])
|
||||
|
||||
|
||||
@@ -4112,21 +4112,6 @@ class ZImageImg2ImgPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ZImageInpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ZImageOmniPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -1,244 +0,0 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import MagCacheConfig, apply_mag_cache
|
||||
from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
|
||||
from diffusers.models import ModelMixin
|
||||
from diffusers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DummyBlock(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
|
||||
# Output is double input
|
||||
# This ensures Residual = 2*Input - Input = Input
|
||||
return hidden_states * 2.0
|
||||
|
||||
|
||||
class DummyTransformer(ModelMixin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.transformer_blocks = torch.nn.ModuleList([DummyBlock(), DummyBlock()])
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None):
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TupleOutputBlock(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
|
||||
# Returns a tuple
|
||||
return hidden_states * 2.0, encoder_hidden_states
|
||||
|
||||
|
||||
class TupleTransformer(ModelMixin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.transformer_blocks = torch.nn.ModuleList([TupleOutputBlock()])
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None):
|
||||
for block in self.transformer_blocks:
|
||||
# Emulate Flux-like behavior
|
||||
output = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
|
||||
hidden_states = output[0]
|
||||
encoder_hidden_states = output[1]
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class MagCacheTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# Register standard dummy block
|
||||
TransformerBlockRegistry.register(
|
||||
DummyBlock,
|
||||
TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None),
|
||||
)
|
||||
# Register tuple block (Flux style)
|
||||
TransformerBlockRegistry.register(
|
||||
TupleOutputBlock,
|
||||
TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1),
|
||||
)
|
||||
|
||||
def _set_context(self, model, context_name):
|
||||
"""Helper to set context on all hooks in the model."""
|
||||
for module in model.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook._set_context(context_name)
|
||||
|
||||
def _get_calibration_data(self, model):
|
||||
for module in model.modules():
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
hook = module._diffusers_hook.get_hook("mag_cache_block_hook")
|
||||
if hook:
|
||||
return hook.state_manager.get_state().calibration_ratios
|
||||
return []
|
||||
|
||||
def test_mag_cache_validation(self):
|
||||
"""Test that missing mag_ratios raises ValueError."""
|
||||
with self.assertRaises(ValueError):
|
||||
MagCacheConfig(num_inference_steps=10, calibrate=False)
|
||||
|
||||
def test_mag_cache_skipping_logic(self):
|
||||
"""
|
||||
Tests that MagCache correctly calculates residuals and skips blocks when conditions are met.
|
||||
"""
|
||||
model = DummyTransformer()
|
||||
|
||||
# Dummy ratios: [1.0, 1.0] implies 0 accumulated error if we skip
|
||||
ratios = np.array([1.0, 1.0])
|
||||
|
||||
config = MagCacheConfig(
|
||||
threshold=100.0,
|
||||
num_inference_steps=2,
|
||||
retention_ratio=0.0, # Enable immediate skipping
|
||||
max_skip_steps=5,
|
||||
mag_ratios=ratios,
|
||||
)
|
||||
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
# Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each)
|
||||
# HeadInput=10. Output=40. Residual=30.
|
||||
input_t0 = torch.tensor([[[10.0]]])
|
||||
output_t0 = model(input_t0)
|
||||
self.assertTrue(torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed")
|
||||
|
||||
# Step 1: Input 11.0.
|
||||
# If Skipped: Output = Input(11) + Residual(30) = 41.0
|
||||
# If Computed: Output = 11 * 4 = 44.0
|
||||
input_t1 = torch.tensor([[[11.0]]])
|
||||
output_t1 = model(input_t1)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(output_t1, torch.tensor([[[41.0]]])), f"Expected Skip (41.0), got {output_t1.item()}"
|
||||
)
|
||||
|
||||
def test_mag_cache_retention(self):
|
||||
"""Test that retention_ratio prevents skipping even if error is low."""
|
||||
model = DummyTransformer()
|
||||
# Ratios that imply 0 error, so it *would* skip if retention allowed it
|
||||
ratios = np.array([1.0, 1.0])
|
||||
|
||||
config = MagCacheConfig(
|
||||
threshold=100.0,
|
||||
num_inference_steps=2,
|
||||
retention_ratio=1.0, # Force retention for ALL steps
|
||||
mag_ratios=ratios,
|
||||
)
|
||||
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
# Step 0
|
||||
model(torch.tensor([[[10.0]]]))
|
||||
|
||||
# Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention
|
||||
input_t1 = torch.tensor([[[11.0]]])
|
||||
output_t1 = model(input_t1)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(output_t1, torch.tensor([[[44.0]]])),
|
||||
f"Expected Compute (44.0) due to retention, got {output_t1.item()}",
|
||||
)
|
||||
|
||||
def test_mag_cache_tuple_outputs(self):
|
||||
"""Test compatibility with models returning (hidden, encoder_hidden) like Flux."""
|
||||
model = TupleTransformer()
|
||||
ratios = np.array([1.0, 1.0])
|
||||
|
||||
config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=ratios)
|
||||
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
# Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x)
|
||||
# Residual = 10.0
|
||||
input_t0 = torch.tensor([[[10.0]]])
|
||||
enc_t0 = torch.tensor([[[1.0]]])
|
||||
out_0, _ = model(input_t0, encoder_hidden_states=enc_t0)
|
||||
self.assertTrue(torch.allclose(out_0, torch.tensor([[[20.0]]])))
|
||||
|
||||
# Step 1: Skip. Input 11.0.
|
||||
# Skipped Output = 11 + 10 = 21.0
|
||||
input_t1 = torch.tensor([[[11.0]]])
|
||||
out_1, _ = model(input_t1, encoder_hidden_states=enc_t0)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(out_1, torch.tensor([[[21.0]]])), f"Tuple skip failed. Expected 21.0, got {out_1.item()}"
|
||||
)
|
||||
|
||||
def test_mag_cache_reset(self):
|
||||
"""Test that state resets correctly after num_inference_steps."""
|
||||
model = DummyTransformer()
|
||||
config = MagCacheConfig(
|
||||
threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0])
|
||||
)
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
input_t = torch.ones(1, 1, 1)
|
||||
|
||||
model(input_t) # Step 0
|
||||
model(input_t) # Step 1 (Skipped)
|
||||
|
||||
# Step 2 (Reset -> Step 0) -> Should Compute
|
||||
# Input 2.0 -> Output 8.0
|
||||
input_t2 = torch.tensor([[[2.0]]])
|
||||
output_t2 = model(input_t2)
|
||||
|
||||
self.assertTrue(torch.allclose(output_t2, torch.tensor([[[8.0]]])), "State did not reset correctly")
|
||||
|
||||
def test_mag_cache_calibration(self):
|
||||
"""Test that calibration mode records ratios."""
|
||||
model = DummyTransformer()
|
||||
config = MagCacheConfig(num_inference_steps=2, calibrate=True)
|
||||
apply_mag_cache(model, config)
|
||||
self._set_context(model, "test_context")
|
||||
|
||||
# Step 0
|
||||
# HeadInput = 10. Output = 40. Residual = 30.
|
||||
# Ratio 0 is placeholder 1.0
|
||||
model(torch.tensor([[[10.0]]]))
|
||||
|
||||
# Check intermediate state
|
||||
ratios = self._get_calibration_data(model)
|
||||
self.assertEqual(len(ratios), 1)
|
||||
self.assertEqual(ratios[0], 1.0)
|
||||
|
||||
# Step 1
|
||||
# HeadInput = 10. Output = 40. Residual = 30.
|
||||
# PrevResidual = 30. CurrResidual = 30.
|
||||
# Ratio = 30/30 = 1.0
|
||||
model(torch.tensor([[[10.0]]]))
|
||||
|
||||
# Verify it computes fully (no skip)
|
||||
# If it skipped, output would be 41.0. It should be 40.0
|
||||
# Actually in test setup, input is same (10.0) so output 40.0.
|
||||
# Let's ensure list is empty after reset (end of step 1)
|
||||
ratios_after = self._get_calibration_data(model)
|
||||
self.assertEqual(ratios_after, [])
|
||||
@@ -27,7 +27,6 @@ from ..test_pipelines_common import (
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
FluxIPAdapterTesterMixin,
|
||||
MagCacheTesterMixin,
|
||||
PipelineTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
TaylorSeerCacheTesterMixin,
|
||||
@@ -42,7 +41,6 @@ class FluxPipelineFastTests(
|
||||
FasterCacheTesterMixin,
|
||||
FirstBlockCacheTesterMixin,
|
||||
TaylorSeerCacheTesterMixin,
|
||||
MagCacheTesterMixin,
|
||||
unittest.TestCase,
|
||||
):
|
||||
pipeline_class = FluxPipeline
|
||||
|
||||
@@ -35,7 +35,6 @@ from diffusers import (
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
|
||||
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
|
||||
from diffusers.hooks.mag_cache import MagCacheConfig
|
||||
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
|
||||
from diffusers.hooks.taylorseer_cache import TaylorSeerCacheConfig
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
@@ -2977,59 +2976,6 @@ class TaylorSeerCacheTesterMixin:
|
||||
)
|
||||
|
||||
|
||||
class MagCacheTesterMixin:
|
||||
mag_cache_config = MagCacheConfig(
|
||||
threshold=0.06,
|
||||
max_skip_steps=3,
|
||||
retention_ratio=0.2,
|
||||
num_inference_steps=50,
|
||||
mag_ratios=torch.ones(50),
|
||||
)
|
||||
|
||||
def test_mag_cache_inference(self, expected_atol: float = 0.1):
|
||||
device = "cpu"
|
||||
|
||||
def create_pipe():
|
||||
torch.manual_seed(0)
|
||||
num_layers = 2
|
||||
components = self.get_dummy_components(num_layers=num_layers)
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
return pipe
|
||||
|
||||
def run_forward(pipe):
|
||||
torch.manual_seed(0)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
# Match the config steps
|
||||
inputs["num_inference_steps"] = 50
|
||||
return pipe(**inputs)[0]
|
||||
|
||||
# 1. Run inference without MagCache (Baseline)
|
||||
pipe = create_pipe()
|
||||
output = run_forward(pipe).flatten()
|
||||
original_image_slice = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
# 2. Run inference with MagCache ENABLED
|
||||
pipe = create_pipe()
|
||||
pipe.transformer.enable_cache(self.mag_cache_config)
|
||||
output = run_forward(pipe).flatten()
|
||||
image_slice_enabled = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
# 3. Run inference with MagCache DISABLED
|
||||
pipe.transformer.disable_cache()
|
||||
output = run_forward(pipe).flatten()
|
||||
image_slice_disabled = np.concatenate((output[:8], output[-8:]))
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_enabled, atol=expected_atol), (
|
||||
"MagCache outputs should not differ too much from baseline."
|
||||
)
|
||||
|
||||
assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-4), (
|
||||
"Outputs after disabling cache should match original inference exactly."
|
||||
)
|
||||
|
||||
|
||||
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
|
||||
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
|
||||
# reference image.
|
||||
|
||||
@@ -1,396 +0,0 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
ZImageInpaintPipeline,
|
||||
ZImageTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import floats_tensor
|
||||
|
||||
from ...testing_utils import torch_device
|
||||
from ..pipeline_params import (
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
|
||||
# Cannot use enable_full_determinism() which sets it to True
|
||||
# Note: Z-Image does not support FP16 inference due to complex64 RoPE embeddings
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||||
torch.use_deterministic_algorithms(False)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
if hasattr(torch.backends, "cuda"):
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class ZImageInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = ZImageInpaintPipeline
|
||||
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
|
||||
image_params = frozenset(["image", "mask_image"])
|
||||
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"strength",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
supports_dduf = False
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = ZImageTransformer2DModel(
|
||||
all_patch_size=(2,),
|
||||
all_f_patch_size=(1,),
|
||||
in_channels=16,
|
||||
dim=32,
|
||||
n_layers=2,
|
||||
n_refiner_layers=1,
|
||||
n_heads=2,
|
||||
n_kv_heads=2,
|
||||
norm_eps=1e-5,
|
||||
qk_norm=True,
|
||||
cap_feat_dim=16,
|
||||
rope_theta=256.0,
|
||||
t_scale=1000.0,
|
||||
axes_dims=[8, 4, 4],
|
||||
axes_lens=[256, 32, 32],
|
||||
)
|
||||
# `x_pad_token` and `cap_pad_token` are initialized with `torch.empty` which contains
|
||||
# uninitialized memory. Set them to known values for deterministic test behavior.
|
||||
with torch.no_grad():
|
||||
transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
|
||||
transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
block_out_channels=[32, 64],
|
||||
layers_per_block=1,
|
||||
latent_channels=16,
|
||||
norm_num_groups=32,
|
||||
sample_size=32,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = Qwen3Config(
|
||||
hidden_size=16,
|
||||
intermediate_size=16,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
vocab_size=151936,
|
||||
max_position_embeddings=512,
|
||||
)
|
||||
text_encoder = Qwen3Model(config)
|
||||
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
import random
|
||||
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
# Create mask: 1 = inpaint region, 0 = preserve region
|
||||
mask_image = torch.zeros((1, 1, 32, 32), device=device)
|
||||
mask_image[:, :, 8:24, 8:24] = 1.0 # Inpaint center region
|
||||
|
||||
inputs = {
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "bad quality",
|
||||
"image": image,
|
||||
"mask_image": mask_image,
|
||||
"strength": 1.0,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 3.0,
|
||||
"cfg_normalization": False,
|
||||
"cfg_truncation": 1.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "np",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
generated_image = image[0]
|
||||
self.assertEqual(generated_image.shape, (32, 32, 3))
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
|
||||
|
||||
def test_num_images_per_prompt(self):
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(self.pipeline_class.__call__)
|
||||
|
||||
if "num_images_per_prompt" not in sig.parameters:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
batch_sizes = [1, 2]
|
||||
num_images_per_prompts = [1, 2]
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for num_images_per_prompt in num_images_per_prompts:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
for key in inputs.keys():
|
||||
if key in self.batch_params:
|
||||
inputs[key] = batch_size * [inputs[key]]
|
||||
|
||||
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
|
||||
|
||||
assert images.shape[0] == batch_size * num_images_per_prompt
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
if not self.test_attention_slicing:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing1 = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=2)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing2 = pipe(**inputs)[0]
|
||||
|
||||
if test_max_difference:
|
||||
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
|
||||
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
|
||||
self.assertLess(
|
||||
max(max_diff1, max_diff2),
|
||||
expected_max_diff,
|
||||
"Attention slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_vae_tiling(self, expected_diff_max: float = 0.7):
|
||||
import random
|
||||
|
||||
generator_device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to("cpu")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Without tiling
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
# Generate a larger image for the input
|
||||
inputs["image"] = floats_tensor((1, 3, 128, 128), rng=random.Random(0)).to("cpu")
|
||||
# Generate a larger mask for the input
|
||||
mask = torch.zeros((1, 1, 128, 128), device="cpu")
|
||||
mask[:, :, 32:96, 32:96] = 1.0
|
||||
inputs["mask_image"] = mask
|
||||
output_without_tiling = pipe(**inputs)[0]
|
||||
|
||||
# With tiling (standard AutoencoderKL doesn't accept parameters)
|
||||
pipe.vae.enable_tiling()
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
inputs["image"] = floats_tensor((1, 3, 128, 128), rng=random.Random(0)).to("cpu")
|
||||
inputs["mask_image"] = mask
|
||||
output_with_tiling = pipe(**inputs)[0]
|
||||
|
||||
self.assertLess(
|
||||
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
|
||||
expected_diff_max,
|
||||
"VAE tiling should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-3):
|
||||
# Z-Image RoPE embeddings (complex64) have slightly higher numerical tolerance
|
||||
# Inpainting mask blending adds additional numerical variance
|
||||
super().test_pipeline_with_accelerator_device_map(expected_max_difference=expected_max_difference)
|
||||
|
||||
def test_group_offloading_inference(self):
|
||||
# Block-level offloading conflicts with RoPE cache. Pipeline-level offloading (tested separately) works fine.
|
||||
self.skipTest("Using test_pipeline_level_group_offloading_inference instead")
|
||||
|
||||
def test_save_load_float16(self, expected_max_diff=1e-2):
|
||||
# Z-Image does not support FP16 due to complex64 RoPE embeddings
|
||||
self.skipTest("Z-Image does not support FP16 inference")
|
||||
|
||||
def test_float16_inference(self, expected_max_diff=5e-2):
|
||||
# Z-Image does not support FP16 due to complex64 RoPE embeddings
|
||||
self.skipTest("Z-Image does not support FP16 inference")
|
||||
|
||||
def test_strength_parameter(self):
|
||||
"""Test that strength parameter affects the output correctly."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Test with different strength values
|
||||
inputs_low_strength = self.get_dummy_inputs(device)
|
||||
inputs_low_strength["strength"] = 0.2
|
||||
|
||||
inputs_high_strength = self.get_dummy_inputs(device)
|
||||
inputs_high_strength["strength"] = 0.8
|
||||
|
||||
# Both should complete without errors
|
||||
output_low = pipe(**inputs_low_strength).images[0]
|
||||
output_high = pipe(**inputs_high_strength).images[0]
|
||||
|
||||
# Outputs should be different (different amount of transformation)
|
||||
self.assertFalse(np.allclose(output_low, output_high, atol=1e-3))
|
||||
|
||||
def test_invalid_strength(self):
|
||||
"""Test that invalid strength values raise appropriate errors."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
|
||||
# Test strength < 0
|
||||
inputs["strength"] = -0.1
|
||||
with self.assertRaises(ValueError):
|
||||
pipe(**inputs)
|
||||
|
||||
# Test strength > 1
|
||||
inputs["strength"] = 1.5
|
||||
with self.assertRaises(ValueError):
|
||||
pipe(**inputs)
|
||||
|
||||
def test_mask_inpainting(self):
|
||||
"""Test that the mask properly controls which regions are inpainted."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Generate with full mask (inpaint everything)
|
||||
inputs_full = self.get_dummy_inputs(device)
|
||||
inputs_full["mask_image"] = torch.ones((1, 1, 32, 32), device=device)
|
||||
|
||||
# Generate with no mask (preserve everything)
|
||||
inputs_none = self.get_dummy_inputs(device)
|
||||
inputs_none["mask_image"] = torch.zeros((1, 1, 32, 32), device=device)
|
||||
|
||||
# Both should complete without errors
|
||||
output_full = pipe(**inputs_full).images[0]
|
||||
output_none = pipe(**inputs_none).images[0]
|
||||
|
||||
# Outputs should be different (full inpaint vs preserve)
|
||||
self.assertFalse(np.allclose(output_full, output_none, atol=1e-3))
|
||||
Reference in New Issue
Block a user