mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-08 20:05:05 +08:00
Compare commits
14 Commits
ltx2-add-c
...
modular-do
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
86fc6691cb | ||
|
|
7224beb036 | ||
|
|
64dba68e0a | ||
|
|
98ea6e0b2e | ||
|
|
64a90fc2e2 | ||
|
|
7fdddf012e | ||
|
|
24cbb354c0 | ||
|
|
025dfd4c67 | ||
|
|
ca79f8ccc4 | ||
|
|
99e2cfff27 | ||
|
|
a3dcd9882f | ||
|
|
9fe0a9cac4 | ||
|
|
03af690b60 | ||
|
|
90818e82b3 |
@@ -53,6 +53,41 @@ image = pipe(
|
||||
image.save("zimage_img2img.png")
|
||||
```
|
||||
|
||||
## Inpainting
|
||||
|
||||
Use [`ZImageInpaintPipeline`] to inpaint specific regions of an image based on a text prompt and mask.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from diffusers import ZImageInpaintPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipe = ZImageInpaintPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
init_image = load_image(url).resize((1024, 1024))
|
||||
|
||||
# Create a mask (white = inpaint, black = preserve)
|
||||
mask = np.zeros((1024, 1024), dtype=np.uint8)
|
||||
mask[256:768, 256:768] = 255 # Inpaint center region
|
||||
mask_image = Image.fromarray(mask)
|
||||
|
||||
prompt = "A beautiful lake with mountains in the background"
|
||||
image = pipe(
|
||||
prompt,
|
||||
image=init_image,
|
||||
mask_image=mask_image,
|
||||
strength=1.0,
|
||||
num_inference_steps=9,
|
||||
guidance_scale=0.0,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).images[0]
|
||||
image.save("zimage_inpaint.png")
|
||||
```
|
||||
|
||||
## ZImagePipeline
|
||||
|
||||
[[autodoc]] ZImagePipeline
|
||||
@@ -64,3 +99,9 @@ image.save("zimage_img2img.png")
|
||||
[[autodoc]] ZImageImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ZImageInpaintPipeline
|
||||
|
||||
[[autodoc]] ZImageInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -12,27 +12,28 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# ModularPipeline
|
||||
|
||||
[`ModularPipeline`] converts [`~modular_pipelines.ModularPipelineBlocks`]'s into an executable pipeline that loads models and performs the computation steps defined in the block. It is the main interface for running a pipeline and it is very similar to the [`DiffusionPipeline`] API.
|
||||
[`ModularPipeline`] converts [`~modular_pipelines.ModularPipelineBlocks`] into an executable pipeline that loads models and performs the computation steps defined in the blocks. It is the main interface for running a pipeline and the API is very similar to [`DiffusionPipeline`] but with a few key differences.
|
||||
|
||||
The main difference is to include an expected `output` argument in the pipeline.
|
||||
- **Loading is lazy.** With [`DiffusionPipeline`], [`~DiffusionPipeline.from_pretrained`] creates the pipeline and loads all models at the same time. With [`ModularPipeline`], creating and loading are two separate steps: [`~ModularPipeline.from_pretrained`] reads the configuration and knows where to load each component from, but doesn't actually load the model weights. You load the models later with [`~ModularPipeline.load_components`], which is where you pass loading arguments like `torch_dtype` and `quantization_config`.
|
||||
|
||||
- **Two ways to create a pipeline.** You can use [`~ModularPipeline.from_pretrained`] with an existing diffusers model repository — it automatically maps to the default pipeline blocks and then converts to a [`ModularPipeline`] with no extra setup. Currently supported models include SDXL, Wan, Qwen, Z-Image, Flux, and Flux2. You can also assemble your own pipeline from [`ModularPipelineBlocks`] and convert it with the [`~ModularPipelineBlocks.init_pipeline`] method (see [Creating a pipeline](#creating-a-pipeline) for more details).
|
||||
|
||||
- **Running the pipeline is the same.** Once loaded, you call the pipeline with the same arguments you're used to. A single [`ModularPipeline`] can support multiple workflows (text-to-image, image-to-image, inpainting, etc.) when the pipeline blocks use [`AutoPipelineBlocks`](./auto_pipeline) to automatically select the workflow based on your inputs.
|
||||
|
||||
Below are complete examples for text-to-image, image-to-image, and inpainting with SDXL.
|
||||
|
||||
<hfoptions id="example">
|
||||
<hfoption id="text-to-image">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
|
||||
|
||||
blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
|
||||
|
||||
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
|
||||
pipeline = blocks.init_pipeline(modular_repo_id)
|
||||
from diffusers import ModularPipeline
|
||||
|
||||
pipeline = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
pipeline.load_components(torch_dtype=torch.float16)
|
||||
pipeline.to("cuda")
|
||||
|
||||
image = pipeline(prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", output="images")[0]
|
||||
image = pipeline(prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0]
|
||||
image.save("modular_t2i_out.png")
|
||||
```
|
||||
|
||||
@@ -41,21 +42,17 @@ image.save("modular_t2i_out.png")
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
|
||||
|
||||
blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2IMAGE_BLOCKS)
|
||||
|
||||
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
|
||||
pipeline = blocks.init_pipeline(modular_repo_id)
|
||||
from diffusers import ModularPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipeline = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
pipeline.load_components(torch_dtype=torch.float16)
|
||||
pipeline.to("cuda")
|
||||
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
|
||||
init_image = load_image(url)
|
||||
prompt = "a dog catching a frisbee in the jungle"
|
||||
image = pipeline(prompt=prompt, image=init_image, strength=0.8, output="images")[0]
|
||||
image = pipeline(prompt=prompt, image=init_image, strength=0.8).images[0]
|
||||
image.save("modular_i2i_out.png")
|
||||
```
|
||||
|
||||
@@ -64,15 +61,10 @@ image.save("modular_i2i_out.png")
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
|
||||
from diffusers import ModularPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
blocks = SequentialPipelineBlocks.from_blocks_dict(INPAINT_BLOCKS)
|
||||
|
||||
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
|
||||
pipeline = blocks.init_pipeline(modular_repo_id)
|
||||
|
||||
pipeline = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
pipeline.load_components(torch_dtype=torch.float16)
|
||||
pipeline.to("cuda")
|
||||
|
||||
@@ -83,276 +75,353 @@ init_image = load_image(img_url)
|
||||
mask_image = load_image(mask_url)
|
||||
|
||||
prompt = "A deep sea diver floating"
|
||||
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85, output="images")[0]
|
||||
image.save("moduar_inpaint_out.png")
|
||||
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85).images[0]
|
||||
image.save("modular_inpaint_out.png")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
This guide will show you how to create a [`ModularPipeline`] and manage the components in it.
|
||||
|
||||
## Adding blocks
|
||||
|
||||
Blocks are [`InsertableDict`] objects that can be inserted at specific positions, providing a flexible way to mix-and-match blocks.
|
||||
|
||||
Use [`~modular_pipelines.modular_pipeline_utils.InsertableDict.insert`] on either the block class or `sub_blocks` attribute to add a block.
|
||||
|
||||
```py
|
||||
# BLOCKS is dict of block classes, you need to add class to it
|
||||
BLOCKS.insert("block_name", BlockClass, index)
|
||||
# sub_blocks attribute contains instance, add a block instance to the attribute
|
||||
t2i_blocks.sub_blocks.insert("block_name", block_instance, index)
|
||||
```
|
||||
|
||||
Use [`~modular_pipelines.modular_pipeline_utils.InsertableDict.pop`] on either the block class or `sub_blocks` attribute to remove a block.
|
||||
|
||||
```py
|
||||
# remove a block class from preset
|
||||
BLOCKS.pop("text_encoder")
|
||||
# split out a block instance on its own
|
||||
text_encoder_block = t2i_blocks.sub_blocks.pop("text_encoder")
|
||||
```
|
||||
|
||||
Swap blocks by setting the existing block to the new block.
|
||||
|
||||
```py
|
||||
# Replace block class in preset
|
||||
BLOCKS["prepare_latents"] = CustomPrepareLatents
|
||||
# Replace in sub_blocks attribute using an block instance
|
||||
t2i_blocks.sub_blocks["prepare_latents"] = CustomPrepareLatents()
|
||||
```
|
||||
This guide will show you how to create a [`ModularPipeline`], manage the components in it, and run it.
|
||||
|
||||
## Creating a pipeline
|
||||
|
||||
There are two ways to create a [`ModularPipeline`]. Assemble and create a pipeline from [`ModularPipelineBlocks`] or load an existing pipeline with [`~ModularPipeline.from_pretrained`].
|
||||
There are two ways to create a [`ModularPipeline`]. Assemble and create a pipeline from [`ModularPipelineBlocks`] with [`~ModularPipelineBlocks.init_pipeline`], or load an existing pipeline with [`~ModularPipeline.from_pretrained`].
|
||||
|
||||
You should also initialize a [`ComponentsManager`] to handle device placement and memory and component management.
|
||||
|
||||
> [!TIP]
|
||||
> Refer to the [ComponentsManager](./components_manager) doc for more details about how it can help manage components across different workflows.
|
||||
|
||||
<hfoptions id="create">
|
||||
<hfoption id="ModularPipelineBlocks">
|
||||
### init_pipeline
|
||||
|
||||
Use the [`~ModularPipelineBlocks.init_pipeline`] method to create a [`ModularPipeline`] from the component and configuration specifications. This method loads the *specifications* from a `modular_model_index.json` file, but it doesn't load the *models* yet.
|
||||
[`~ModularPipelineBlocks.init_pipeline`] converts any [`ModularPipelineBlocks`] into a [`ModularPipeline`].
|
||||
|
||||
Let's define a minimal block to see how it works:
|
||||
|
||||
```py
|
||||
from diffusers import ComponentsManager
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
|
||||
from transformers import CLIPTextModel
|
||||
from diffusers.modular_pipelines import (
|
||||
ComponentSpec,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
|
||||
t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
|
||||
class MyBlock(ModularPipelineBlocks):
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [
|
||||
ComponentSpec(
|
||||
name="text_encoder",
|
||||
type_hint=CLIPTextModel,
|
||||
pretrained_model_name_or_path="openai/clip-vit-large-patch14",
|
||||
),
|
||||
]
|
||||
|
||||
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
|
||||
components = ComponentsManager()
|
||||
t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
return components, state
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="from_pretrained">
|
||||
Call [`~ModularPipelineBlocks.init_pipeline`] to convert it into a pipeline. The `blocks` attribute on the pipeline is the blocks it was created from — it determines the expected inputs, outputs, and computation logic.
|
||||
|
||||
The [`~ModularPipeline.from_pretrained`] method creates a [`ModularPipeline`] from a modular repository on the Hub.
|
||||
```py
|
||||
block = MyBlock()
|
||||
pipe = block.init_pipeline()
|
||||
pipe.blocks
|
||||
```
|
||||
|
||||
```
|
||||
MyBlock {
|
||||
"_class_name": "MyBlock",
|
||||
"_diffusers_version": "0.37.0.dev0"
|
||||
}
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Blocks are mutable — you can freely add, remove, or swap blocks before creating a pipeline. However, once a pipeline is created, modifying `pipeline.blocks` won't affect the pipeline because it returns a copy. If you want a different block structure, create a new pipeline after modifying the blocks.
|
||||
|
||||
When you call [`~ModularPipelineBlocks.init_pipeline`] without a repository, it uses the `pretrained_model_name_or_path` defined in the block's [`ComponentSpec`] to determine where to load each component from. Printing the pipeline shows the component loading configuration.
|
||||
|
||||
```py
|
||||
pipe
|
||||
ModularPipeline {
|
||||
"_blocks_class_name": "MyBlock",
|
||||
"_class_name": "ModularPipeline",
|
||||
"_diffusers_version": "0.37.0.dev0",
|
||||
"text_encoder": [
|
||||
null,
|
||||
null,
|
||||
{
|
||||
"pretrained_model_name_or_path": "openai/clip-vit-large-patch14",
|
||||
"revision": null,
|
||||
"subfolder": "",
|
||||
"type_hint": [
|
||||
"transformers",
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"variant": null
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
If you pass a repository to [`~ModularPipelineBlocks.init_pipeline`], it overrides the loading path by matching your block's components against the pipeline config in that repository (`model_index.json` or `modular_model_index.json`).
|
||||
|
||||
In the example below, the `pretrained_model_name_or_path` will be updated to `"stabilityai/stable-diffusion-xl-base-1.0"`.
|
||||
|
||||
```py
|
||||
pipe = block.init_pipeline("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
pipe
|
||||
ModularPipeline {
|
||||
"_blocks_class_name": "MyBlock",
|
||||
"_class_name": "ModularPipeline",
|
||||
"_diffusers_version": "0.37.0.dev0",
|
||||
"text_encoder": [
|
||||
null,
|
||||
null,
|
||||
{
|
||||
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"revision": null,
|
||||
"subfolder": "text_encoder",
|
||||
"type_hint": [
|
||||
"transformers",
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"variant": null
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
If a component in your block doesn't exist in the repository, it remains `null` and is skipped during [`~ModularPipeline.load_components`].
|
||||
|
||||
### from_pretrained
|
||||
|
||||
[`~ModularPipeline.from_pretrained`] is a convenient way to create a [`ModularPipeline`] without defining blocks yourself.
|
||||
|
||||
It works with three types of repositories.
|
||||
|
||||
**A regular diffusers repository.** Pass any supported model repository and it automatically maps to the default pipeline blocks. Currently supported models include SDXL, Wan, Qwen, Z-Image, Flux, and Flux2.
|
||||
|
||||
```py
|
||||
from diffusers import ModularPipeline, ComponentsManager
|
||||
|
||||
components = ComponentsManager()
|
||||
pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-loader-t2i-0704", components_manager=components)
|
||||
pipeline = ModularPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", components_manager=components
|
||||
)
|
||||
```
|
||||
|
||||
Add the `trust_remote_code` argument to load a custom [`ModularPipeline`].
|
||||
**A modular repository.** These repositories contain a `modular_model_index.json` that specifies where to load each component from — the components can come from different repositories and the modular repository itself may not contain any model weights. For example, [diffusers/flux2-bnb-4bit-modular](https://huggingface.co/diffusers/flux2-bnb-4bit-modular) loads a quantized transformer from one repository and the remaining components from another. See [Modular repository](#modular-repository) for more details on the format.
|
||||
|
||||
```py
|
||||
from diffusers import ModularPipeline, ComponentsManager
|
||||
|
||||
components = ComponentsManager()
|
||||
modular_repo_id = "YiYiXu/modular-diffdiff-0704"
|
||||
diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remote_code=True, components_manager=components)
|
||||
pipeline = ModularPipeline.from_pretrained(
|
||||
"diffusers/flux2-bnb-4bit-modular", components_manager=components
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
**A modular repository with custom code.** Some repositories include custom pipeline blocks alongside the loading configuration. Add `trust_remote_code=True` to load them. See [Custom blocks](./custom_blocks) for how to create your own.
|
||||
|
||||
```py
|
||||
from diffusers import ModularPipeline, ComponentsManager
|
||||
|
||||
components = ComponentsManager()
|
||||
pipeline = ModularPipeline.from_pretrained(
|
||||
"diffusers/Florence2-image-Annotator", trust_remote_code=True, components_manager=components
|
||||
)
|
||||
```
|
||||
|
||||
## Loading components
|
||||
|
||||
A [`ModularPipeline`] doesn't automatically instantiate with components. It only loads the configuration and component specifications. You can load all components with [`~ModularPipeline.load_components`] or only load specific components with [`~ModularPipeline.load_components`].
|
||||
A [`ModularPipeline`] doesn't automatically instantiate with components. It only loads the configuration and component specifications. You can load components with [`~ModularPipeline.load_components`].
|
||||
|
||||
<hfoptions id="load">
|
||||
<hfoption id="load_components">
|
||||
This will load all the components that have a valid loading spec.
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
t2i_pipeline.load_components(torch_dtype=torch.float16)
|
||||
t2i_pipeline.to("cuda")
|
||||
pipeline.load_components(torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="load_components">
|
||||
|
||||
The example below only loads the UNet and VAE.
|
||||
You can also load specific components by name. The example below only loads the text_encoder.
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
t2i_pipeline.load_components(names=["unet", "vae"], torch_dtype=torch.float16)
|
||||
pipeline.load_components(names=["text_encoder"], torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Print the pipeline to inspect the loaded pretrained components.
|
||||
After loading, printing the pipeline shows which components are loaded — the first two fields change from `null` to the component's library and class.
|
||||
|
||||
```py
|
||||
t2i_pipeline
|
||||
pipeline
|
||||
```
|
||||
|
||||
This should match the `modular_model_index.json` file from the modular repository a pipeline is initialized from. If a pipeline doesn't need a component, it won't be included even if it exists in the modular repository.
|
||||
|
||||
To modify where components are loaded from, edit the `modular_model_index.json` file in the repository and change it to your desired loading path. The example below loads a UNet from a different repository.
|
||||
|
||||
```json
|
||||
# original
|
||||
"unet": [
|
||||
null, null,
|
||||
{
|
||||
"repo": "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"subfolder": "unet",
|
||||
"variant": "fp16"
|
||||
}
|
||||
```
|
||||
# text_encoder is loaded - shows library and class
|
||||
"text_encoder": [
|
||||
"transformers",
|
||||
"CLIPTextModel",
|
||||
{ ... }
|
||||
]
|
||||
|
||||
# modified
|
||||
# unet is not loaded yet - still null
|
||||
"unet": [
|
||||
null, null,
|
||||
{
|
||||
"repo": "RunDiffusion/Juggernaut-XL-v9",
|
||||
"subfolder": "unet",
|
||||
"variant": "fp16"
|
||||
}
|
||||
null,
|
||||
null,
|
||||
{ ... }
|
||||
]
|
||||
```
|
||||
|
||||
### Component loading status
|
||||
|
||||
The pipeline properties below provide more information about which components are loaded.
|
||||
|
||||
Use `component_names` to return all expected components.
|
||||
Loading keyword arguments like `torch_dtype`, `variant`, `revision`, and `quantization_config` are passed through to `from_pretrained()` for each component. You can pass a single value to apply to all components, or a dict to set per-component values.
|
||||
|
||||
```py
|
||||
t2i_pipeline.component_names
|
||||
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'guider', 'scheduler', 'unet', 'vae', 'image_processor']
|
||||
# apply bfloat16 to all components
|
||||
pipeline.load_components(torch_dtype=torch.bfloat16)
|
||||
|
||||
# different dtypes per component
|
||||
pipeline.load_components(torch_dtype={"transformer": torch.bfloat16, "default": torch.float32})
|
||||
```
|
||||
|
||||
Use `null_component_names` to return components that aren't loaded yet. Load these components with [`~ModularPipeline.from_pretrained`].
|
||||
|
||||
```py
|
||||
t2i_pipeline.null_component_names
|
||||
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler']
|
||||
```
|
||||
|
||||
Use `pretrained_component_names` to return components that will be loaded from pretrained models.
|
||||
|
||||
```py
|
||||
t2i_pipeline.pretrained_component_names
|
||||
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler', 'unet', 'vae']
|
||||
```
|
||||
|
||||
Use `config_component_names` to return components that are created with the default config (not loaded from a modular repository). Components from a config aren't included because they are already initialized during pipeline creation. This is why they aren't listed in `null_component_names`.
|
||||
|
||||
```py
|
||||
t2i_pipeline.config_component_names
|
||||
['guider', 'image_processor']
|
||||
```
|
||||
Note that [`~ModularPipeline.load_components`] only loads components that haven't been loaded yet and have a valid loading spec. This means if you've already set a component on the pipeline, calling [`~ModularPipeline.load_components`] again won't reload it.
|
||||
|
||||
## Updating components
|
||||
|
||||
Components may be updated depending on whether it is a *pretrained component* or a *config component*.
|
||||
[`~ModularPipeline.update_components`] replaces a component on the pipeline with a new one. When a component is updated, the loading specifications are also updated in the pipeline config and [`~ModularPipeline.load_components`] will skip it on subsequent calls.
|
||||
|
||||
> [!WARNING]
|
||||
> A component may change from pretrained to config when updating a component. The component type is initially defined in a block's `expected_components` field.
|
||||
### From AutoModel
|
||||
|
||||
A pretrained component is updated with [`ComponentSpec`] whereas a config component is updated by eihter passing the object directly or with [`ComponentSpec`].
|
||||
|
||||
The [`ComponentSpec`] shows `default_creation_method="from_pretrained"` for a pretrained component shows `default_creation_method="from_config` for a config component.
|
||||
|
||||
To update a pretrained component, create a [`ComponentSpec`] with the name of the component and where to load it from. Use the [`~ComponentSpec.load`] method to load the component.
|
||||
You can pass a model object loaded with `AutoModel.from_pretrained()`. Models loaded this way are automatically tagged with their loading information.
|
||||
|
||||
```py
|
||||
from diffusers import ComponentSpec, UNet2DConditionModel
|
||||
from diffusers import AutoModel
|
||||
|
||||
unet_spec = ComponentSpec(name="unet",type_hint=UNet2DConditionModel, repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16")
|
||||
unet = unet_spec.load(torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
The [`~ModularPipeline.update_components`] method replaces the component with a new one.
|
||||
|
||||
```py
|
||||
t2i_pipeline.update_components(unet=unet2)
|
||||
```
|
||||
|
||||
When a component is updated, the loading specifications are also updated in the pipeline config.
|
||||
|
||||
### Component extraction and modification
|
||||
|
||||
When you use [`~ComponentSpec.load`], the new component maintains its loading specifications. This makes it possible to extract the specification and recreate the component.
|
||||
|
||||
```py
|
||||
spec = ComponentSpec.from_component("unet", unet2)
|
||||
spec
|
||||
ComponentSpec(name='unet', type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>, description=None, config=None, repo='stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet', variant='fp16', revision=None, default_creation_method='from_pretrained')
|
||||
unet2_recreated = spec.load(torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
The [`~ModularPipeline.get_component_spec`] method gets a copy of the current component specification to modify or update.
|
||||
|
||||
```py
|
||||
unet_spec = t2i_pipeline.get_component_spec("unet")
|
||||
unet_spec
|
||||
ComponentSpec(
|
||||
name='unet',
|
||||
type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
|
||||
pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
|
||||
subfolder='unet',
|
||||
variant='fp16',
|
||||
default_creation_method='from_pretrained'
|
||||
unet = AutoModel.from_pretrained(
|
||||
"RunDiffusion/Juggernaut-XL-v9", subfolder="unet", variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline.update_components(unet=unet)
|
||||
```
|
||||
|
||||
### From ComponentSpec
|
||||
|
||||
Use [`~ModularPipeline.get_component_spec`] to get a copy of the current component specification, modify it, and load a new component.
|
||||
|
||||
```py
|
||||
unet_spec = pipeline.get_component_spec("unet")
|
||||
|
||||
# modify to load from a different repository
|
||||
unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
unet_spec.pretrained_model_name_or_path = "RunDiffusion/Juggernaut-XL-v9"
|
||||
|
||||
# load component with modified spec
|
||||
# load and update
|
||||
unet = unet_spec.load(torch_dtype=torch.float16)
|
||||
pipeline.update_components(unet=unet)
|
||||
```
|
||||
|
||||
You can also create a [`ComponentSpec`] from scratch.
|
||||
|
||||
Not all components are loaded from pretrained weights — some are created from a config (listed under `pipeline.config_component_names`). For these, use [`~ComponentSpec.create`] instead of [`~ComponentSpec.load`].
|
||||
|
||||
```py
|
||||
guider_spec = pipeline.get_component_spec("guider")
|
||||
guider_spec.config = {"guidance_scale": 5.0}
|
||||
guider = guider_spec.create()
|
||||
pipeline.update_components(guider=guider)
|
||||
```
|
||||
|
||||
Or simply pass the object directly.
|
||||
|
||||
```py
|
||||
from diffusers.guiders import ClassifierFreeGuidance
|
||||
|
||||
guider = ClassifierFreeGuidance(guidance_scale=5.0)
|
||||
pipeline.update_components(guider=guider)
|
||||
```
|
||||
|
||||
See the [Guiders](./guiders) guide for more details on available guiders and how to configure them.
|
||||
|
||||
## Splitting a pipeline into stages
|
||||
|
||||
Since blocks are composable, you can take a pipeline apart and reconstruct it into separate pipelines for each stage. The example below shows how we can separate the text encoder block from the rest of the pipeline, so you can encode the prompt independently and pass the embeddings to the main pipeline.
|
||||
|
||||
```py
|
||||
from diffusers import ModularPipeline, ComponentsManager
|
||||
import torch
|
||||
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
repo_id = "black-forest-labs/FLUX.2-klein-4B"
|
||||
|
||||
# get the blocks and separate out the text encoder
|
||||
blocks = ModularPipeline.from_pretrained(repo_id).blocks
|
||||
text_block = blocks.sub_blocks.pop("text_encoder")
|
||||
|
||||
# use ComponentsManager to handle offloading across multiple pipelines
|
||||
manager = ComponentsManager()
|
||||
manager.enable_auto_cpu_offload(device=device)
|
||||
|
||||
# create separate pipelines for each stage
|
||||
text_encoder_pipeline = text_block.init_pipeline(repo_id, components_manager=manager)
|
||||
pipeline = blocks.init_pipeline(repo_id, components_manager=manager)
|
||||
|
||||
# encode text
|
||||
text_encoder_pipeline.load_components(torch_dtype=dtype)
|
||||
text_embeddings = text_encoder_pipeline(prompt="a cat").get_by_kwargs("denoiser_input_fields")
|
||||
|
||||
# denoise and decode
|
||||
pipeline.load_components(torch_dtype=dtype)
|
||||
output = pipeline(
|
||||
**text_embeddings,
|
||||
num_inference_steps=4,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
[`ComponentsManager`] handles memory across multiple pipelines. Unlike the offloading strategies in [`DiffusionPipeline`] that follow a fixed order, [`ComponentsManager`] makes offloading decisions dynamically each time a model forward pass runs, based on the current memory situation. This means it works regardless of how many pipelines you create or what order you run them in. See the [ComponentsManager](./components_manager) guide for more details.
|
||||
|
||||
If pipeline stages share components (e.g., the same VAE used for encoding and decoding), you can use [`~ModularPipeline.update_components`] to pass an already-loaded component to another pipeline instead of loading it again.
|
||||
|
||||
## Modular repository
|
||||
|
||||
A repository is required if the pipeline blocks use *pretrained components*. The repository supplies loading specifications and metadata.
|
||||
|
||||
[`ModularPipeline`] specifically requires *modular repositories* (see [example repository](https://huggingface.co/YiYiXu/modular-diffdiff)) which are more flexible than a typical repository. It contains a `modular_model_index.json` file containing the following 3 elements.
|
||||
[`ModularPipeline`] works with regular diffusers repositories out of the box. However, you can also create a *modular repository* for more flexibility. A modular repository contains a `modular_model_index.json` file containing the following 3 elements.
|
||||
|
||||
- `library` and `class` shows which library the component was loaded from and it's class. If `null`, the component hasn't been loaded yet.
|
||||
- `library` and `class` shows which library the component was loaded from and its class. If `null`, the component hasn't been loaded yet.
|
||||
- `loading_specs_dict` contains the information required to load the component such as the repository and subfolder it is loaded from.
|
||||
|
||||
Unlike standard repositories, a modular repository can fetch components from different repositories based on the `loading_specs_dict`. Components don't need to exist in the same repository.
|
||||
The key advantage of a modular repository is that components can be loaded from different repositories. For example, [diffusers/flux2-bnb-4bit-modular](https://huggingface.co/diffusers/flux2-bnb-4bit-modular) loads a quantized transformer from `diffusers/FLUX.2-dev-bnb-4bit` while loading the remaining components from `black-forest-labs/FLUX.2-dev`.
|
||||
|
||||
A modular repository may contain custom code for loading a [`ModularPipeline`]. This allows you to use specialized blocks that aren't native to Diffusers.
|
||||
To convert a regular diffusers repository into a modular one, create the pipeline using the regular repository, and then push to the Hub. The saved repository will contain a `modular_model_index.json` with all the loading specifications.
|
||||
|
||||
```py
|
||||
from diffusers import ModularPipeline
|
||||
|
||||
# load from a regular repo
|
||||
pipeline = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
|
||||
# push as a modular repository
|
||||
pipeline.save_pretrained("local/path", repo_id="my-username/sdxl-modular", push_to_hub=True)
|
||||
```
|
||||
|
||||
A modular repository can also include custom pipeline blocks as Python code. This allows you to share specialized blocks that aren't native to Diffusers. For example, [diffusers/Florence2-image-Annotator](https://huggingface.co/diffusers/Florence2-image-Annotator) contains custom blocks alongside the loading configuration:
|
||||
|
||||
```
|
||||
modular-diffdiff-0704/
|
||||
Florence2-image-Annotator/
|
||||
├── block.py # Custom pipeline blocks implementation
|
||||
├── config.json # Pipeline configuration and auto_map
|
||||
├── mellon_config.json # UI configuration for Mellon
|
||||
└── modular_model_index.json # Component loading specifications
|
||||
```
|
||||
|
||||
The [config.json](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/config.json) file contains an `auto_map` key that points to where a custom block is defined in `block.py`.
|
||||
The `config.json` file contains an `auto_map` key that tells [`ModularPipeline`] where to find the custom blocks:
|
||||
|
||||
```json
|
||||
{
|
||||
"_class_name": "DiffDiffBlocks",
|
||||
"_class_name": "Florence2AnnotatorBlocks",
|
||||
"auto_map": {
|
||||
"ModularPipelineBlocks": "block.DiffDiffBlocks"
|
||||
"ModularPipelineBlocks": "block.Florence2AnnotatorBlocks"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Load custom code repositories with `trust_remote_code=True` as shown in [from_pretrained](#from_pretrained). See [Custom blocks](./custom_blocks) for how to create and share your own.
|
||||
@@ -66,7 +66,7 @@ from diffusers import DiffusionPipeline, PipelineQuantizationConfig, TorchAoConf
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
|
||||
pipeline_quant_config = PipelineQuantizationConfig(
|
||||
quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128)))}
|
||||
quant_mapping={"transformer": TorchAoConfig(Int4WeightOnlyConfig(group_size=128))}
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
|
||||
@@ -554,7 +554,6 @@ else:
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LongCatImageEditPipeline",
|
||||
"LongCatImagePipeline",
|
||||
"LTX2ConditionPipeline",
|
||||
"LTX2ImageToVideoPipeline",
|
||||
"LTX2LatentUpsamplePipeline",
|
||||
"LTX2Pipeline",
|
||||
@@ -697,6 +696,7 @@ else:
|
||||
"ZImageControlNetInpaintPipeline",
|
||||
"ZImageControlNetPipeline",
|
||||
"ZImageImg2ImgPipeline",
|
||||
"ZImageInpaintPipeline",
|
||||
"ZImageOmniPipeline",
|
||||
"ZImagePipeline",
|
||||
]
|
||||
@@ -1289,7 +1289,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LongCatImageEditPipeline,
|
||||
LongCatImagePipeline,
|
||||
LTX2ConditionPipeline,
|
||||
LTX2ImageToVideoPipeline,
|
||||
LTX2LatentUpsamplePipeline,
|
||||
LTX2Pipeline,
|
||||
@@ -1430,6 +1429,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ZImageControlNetInpaintPipeline,
|
||||
ZImageControlNetPipeline,
|
||||
ZImageImg2ImgPipeline,
|
||||
ZImageInpaintPipeline,
|
||||
ZImageOmniPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
|
||||
@@ -125,9 +125,9 @@ class BriaFiboAttnProcessor:
|
||||
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
||||
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
||||
)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[0](hidden_states.contiguous())
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states.contiguous())
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
|
||||
@@ -130,9 +130,9 @@ class FluxAttnProcessor:
|
||||
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
||||
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
||||
)
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[0](hidden_states.contiguous())
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states.contiguous())
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
|
||||
@@ -561,11 +561,11 @@ class QwenDoubleStreamAttnProcessor2_0:
|
||||
img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
|
||||
|
||||
# Apply output projections
|
||||
img_attn_output = attn.to_out[0](img_attn_output)
|
||||
img_attn_output = attn.to_out[0](img_attn_output.contiguous())
|
||||
if len(attn.to_out) > 1:
|
||||
img_attn_output = attn.to_out[1](img_attn_output) # dropout
|
||||
|
||||
txt_attn_output = attn.to_add_out(txt_attn_output)
|
||||
txt_attn_output = attn.to_add_out(txt_attn_output.contiguous())
|
||||
|
||||
return img_attn_output, txt_attn_output
|
||||
|
||||
|
||||
@@ -2016,58 +2016,29 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- the `config` dict, which will be saved as `modular_model_index.json` during `save_pretrained`
|
||||
|
||||
Args:
|
||||
**kwargs: Component objects, ComponentSpec objects, or configuration values to update:
|
||||
- Component objects: Only supports components we can extract specs using
|
||||
`ComponentSpec.from_component()` method i.e. components created with ComponentSpec.load() or
|
||||
ConfigMixin subclasses that aren't nn.Modules (e.g., `unet=new_unet, text_encoder=new_encoder`)
|
||||
- ComponentSpec objects: Only supports default_creation_method == "from_config", will call create()
|
||||
method to create a new component (e.g., `guider=ComponentSpec(name="guider",
|
||||
type_hint=ClassifierFreeGuidance, config={...}, default_creation_method="from_config")`)
|
||||
- Configuration values: Simple values to update configuration settings (e.g.,
|
||||
`requires_safety_checker=False`)
|
||||
|
||||
Raises:
|
||||
ValueError: If a component object is not supported in ComponentSpec.from_component() method:
|
||||
- nn.Module components without a valid `_diffusers_load_id` attribute
|
||||
- Non-ConfigMixin components without a valid `_diffusers_load_id` attribute
|
||||
**kwargs: Component objects or configuration values to update:
|
||||
- Component objects: Models loaded with `AutoModel.from_pretrained()` or `ComponentSpec.load()`
|
||||
are automatically tagged with loading information. ConfigMixin objects without weights (e.g.,
|
||||
schedulers, guiders) can be passed directly.
|
||||
- Configuration values: Simple values to update configuration settings
|
||||
(e.g., `requires_safety_checker=False`)
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Update multiple components at once
|
||||
# Update pretrrained model
|
||||
pipeline.update_components(unet=new_unet_model, text_encoder=new_text_encoder)
|
||||
|
||||
# Update configuration values
|
||||
pipeline.update_components(requires_safety_checker=False)
|
||||
|
||||
# Update both components and configs together
|
||||
pipeline.update_components(unet=new_unet_model, requires_safety_checker=False)
|
||||
|
||||
# Update with ComponentSpec objects (from_config only)
|
||||
pipeline.update_components(
|
||||
guider=ComponentSpec(
|
||||
name="guider",
|
||||
type_hint=ClassifierFreeGuidance,
|
||||
config={"guidance_scale": 5.0},
|
||||
default_creation_method="from_config",
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
Notes:
|
||||
- Components with trained weights must be created using ComponentSpec.load(). If the component has not been
|
||||
shared in huggingface hub and you don't have loading specs, you can upload it using `push_to_hub()`
|
||||
- ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly
|
||||
- ComponentSpec objects with default_creation_method="from_pretrained" are not supported in
|
||||
update_components()
|
||||
- Components with trained weights should be loaded with `AutoModel.from_pretrained()` or
|
||||
`ComponentSpec.load()` so that loading specs are preserved for serialization.
|
||||
- ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly.
|
||||
"""
|
||||
|
||||
# extract component_specs_updates & config_specs_updates from `specs`
|
||||
passed_component_specs = {
|
||||
k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)
|
||||
}
|
||||
passed_components = {
|
||||
k: kwargs.pop(k) for k in self._component_specs if k in kwargs and not isinstance(kwargs[k], ComponentSpec)
|
||||
}
|
||||
passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs}
|
||||
passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs}
|
||||
|
||||
for name, component in passed_components.items():
|
||||
@@ -2106,33 +2077,14 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
if len(kwargs) > 0:
|
||||
logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
|
||||
|
||||
created_components = {}
|
||||
for name, component_spec in passed_component_specs.items():
|
||||
if component_spec.default_creation_method == "from_pretrained":
|
||||
raise ValueError(
|
||||
"ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update_components() method"
|
||||
)
|
||||
created_components[name] = component_spec.create()
|
||||
current_component_spec = self._component_specs[name]
|
||||
# warn if type changed
|
||||
if current_component_spec.type_hint is not None and not isinstance(
|
||||
created_components[name], current_component_spec.type_hint
|
||||
):
|
||||
logger.info(
|
||||
f"ModularPipeline.update_components: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
|
||||
)
|
||||
# update _component_specs based on the user passed component_spec
|
||||
self._component_specs[name] = component_spec
|
||||
self.register_components(**passed_components, **created_components)
|
||||
self.register_components(**passed_components)
|
||||
|
||||
config_to_register = {}
|
||||
for name, new_value in passed_config_values.items():
|
||||
# e.g. requires_aesthetics_score = False
|
||||
self._config_specs[name].default = new_value
|
||||
config_to_register[name] = new_value
|
||||
self.register_to_config(**config_to_register)
|
||||
|
||||
# YiYi TODO: support map for additional from_pretrained kwargs
|
||||
def load_components(self, names: Optional[Union[List[str], str]] = None, **kwargs):
|
||||
"""
|
||||
Load selected components from specs.
|
||||
|
||||
@@ -291,12 +291,7 @@ else:
|
||||
"LTXLatentUpsamplePipeline",
|
||||
"LTXI2VLongMultiPromptPipeline",
|
||||
]
|
||||
_import_structure["ltx2"] = [
|
||||
"LTX2Pipeline",
|
||||
"LTX2ConditionPipeline",
|
||||
"LTX2ImageToVideoPipeline",
|
||||
"LTX2LatentUpsamplePipeline",
|
||||
]
|
||||
_import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"]
|
||||
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
|
||||
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["lucy"] = ["LucyEditPipeline"]
|
||||
@@ -415,11 +410,12 @@ else:
|
||||
"Kandinsky5I2IPipeline",
|
||||
]
|
||||
_import_structure["z_image"] = [
|
||||
"ZImageImg2ImgPipeline",
|
||||
"ZImagePipeline",
|
||||
"ZImageControlNetPipeline",
|
||||
"ZImageControlNetInpaintPipeline",
|
||||
"ZImageControlNetPipeline",
|
||||
"ZImageImg2ImgPipeline",
|
||||
"ZImageInpaintPipeline",
|
||||
"ZImageOmniPipeline",
|
||||
"ZImagePipeline",
|
||||
]
|
||||
_import_structure["skyreels_v2"] = [
|
||||
"SkyReelsV2DiffusionForcingPipeline",
|
||||
@@ -747,7 +743,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
)
|
||||
from .ltx2 import LTX2ConditionPipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline
|
||||
from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline
|
||||
from .lucy import LucyEditPipeline
|
||||
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
|
||||
@@ -875,6 +871,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ZImageControlNetInpaintPipeline,
|
||||
ZImageControlNetPipeline,
|
||||
ZImageImg2ImgPipeline,
|
||||
ZImageInpaintPipeline,
|
||||
ZImageOmniPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
|
||||
@@ -127,6 +127,7 @@ from .z_image import (
|
||||
ZImageControlNetInpaintPipeline,
|
||||
ZImageControlNetPipeline,
|
||||
ZImageImg2ImgPipeline,
|
||||
ZImageInpaintPipeline,
|
||||
ZImageOmniPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
@@ -235,6 +236,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
|
||||
("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
|
||||
("qwenimage", QwenImageInpaintPipeline),
|
||||
("qwenimage-edit", QwenImageEditInpaintPipeline),
|
||||
("z-image", ZImageInpaintPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ else:
|
||||
_import_structure["connectors"] = ["LTX2TextConnectors"]
|
||||
_import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"]
|
||||
_import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
|
||||
_import_structure["pipeline_ltx2_condition"] = ["LTX2ConditionPipeline"]
|
||||
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
|
||||
_import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"]
|
||||
_import_structure["vocoder"] = ["LTX2Vocoder"]
|
||||
@@ -41,7 +40,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .latent_upsampler import LTX2LatentUpsamplerModel
|
||||
from .pipeline_ltx2 import LTX2Pipeline
|
||||
from .pipeline_ltx2_condition import LTX2ConditionPipeline
|
||||
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
|
||||
from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
|
||||
from .vocoder import LTX2Vocoder
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -26,6 +26,7 @@ else:
|
||||
_import_structure["pipeline_z_image_controlnet"] = ["ZImageControlNetPipeline"]
|
||||
_import_structure["pipeline_z_image_controlnet_inpaint"] = ["ZImageControlNetInpaintPipeline"]
|
||||
_import_structure["pipeline_z_image_img2img"] = ["ZImageImg2ImgPipeline"]
|
||||
_import_structure["pipeline_z_image_inpaint"] = ["ZImageInpaintPipeline"]
|
||||
_import_structure["pipeline_z_image_omni"] = ["ZImageOmniPipeline"]
|
||||
|
||||
|
||||
@@ -42,6 +43,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_z_image_controlnet import ZImageControlNetPipeline
|
||||
from .pipeline_z_image_controlnet_inpaint import ZImageControlNetInpaintPipeline
|
||||
from .pipeline_z_image_img2img import ZImageImg2ImgPipeline
|
||||
from .pipeline_z_image_inpaint import ZImageInpaintPipeline
|
||||
from .pipeline_z_image_omni import ZImageOmniPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
932
src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py
Normal file
932
src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py
Normal file
@@ -0,0 +1,932 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, PreTrainedModel
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import ZImageTransformer2DModel
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from .pipeline_output import ZImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import ZImageInpaintPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = ZImageInpaintPipeline.from_pretrained("Tongyi-MAI/Z-Image-Turbo", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
>>> init_image = load_image(url).resize((1024, 1024))
|
||||
|
||||
>>> # Create a mask (white = inpaint, black = preserve)
|
||||
>>> import numpy as np
|
||||
>>> from PIL import Image
|
||||
|
||||
>>> mask = np.zeros((1024, 1024), dtype=np.uint8)
|
||||
>>> mask[256:768, 256:768] = 255 # Inpaint center region
|
||||
>>> mask_image = Image.fromarray(mask)
|
||||
|
||||
>>> prompt = "A beautiful lake with mountains in the background"
|
||||
>>> image = pipe(
|
||||
... prompt,
|
||||
... image=init_image,
|
||||
... mask_image=mask_image,
|
||||
... strength=1.0,
|
||||
... num_inference_steps=9,
|
||||
... guidance_scale=0.0,
|
||||
... generator=torch.Generator("cuda").manual_seed(42),
|
||||
... ).images[0]
|
||||
>>> image.save("zimage_inpaint.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class ZImageInpaintPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
|
||||
r"""
|
||||
The ZImage pipeline for inpainting.
|
||||
|
||||
Args:
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`PreTrainedModel`]):
|
||||
A text encoder model to encode text prompts.
|
||||
tokenizer ([`AutoTokenizer`]):
|
||||
A tokenizer to tokenize text prompts.
|
||||
transformer ([`ZImageTransformer2DModel`]):
|
||||
A ZImage transformer model to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "mask", "masked_image_latents"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: PreTrainedModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
transformer: ZImageTransformer2DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
transformer=transformer,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor * 2,
|
||||
do_normalize=False,
|
||||
do_binarize=True,
|
||||
do_convert_grayscale=True,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ["" for _ in prompt]
|
||||
else:
|
||||
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
assert len(prompt) == len(negative_prompt)
|
||||
negative_prompt_embeds = self._encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
device=device,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
else:
|
||||
negative_prompt_embeds = []
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
) -> List[torch.FloatTensor]:
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt_embeds is not None:
|
||||
return prompt_embeds
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
for i, prompt_item in enumerate(prompt):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt_item},
|
||||
]
|
||||
prompt_item = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
)
|
||||
prompt[i] = prompt_item
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_masks,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
|
||||
embeddings_list = []
|
||||
|
||||
for i in range(len(prompt_embeds)):
|
||||
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
|
||||
|
||||
return embeddings_list
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
||||
|
||||
t_start = int(max(num_inference_steps - init_timestep, 0))
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
def prepare_mask_latents(
|
||||
self,
|
||||
mask,
|
||||
masked_image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
):
|
||||
"""Prepare mask and masked image latents for inpainting.
|
||||
|
||||
Args:
|
||||
mask: Binary mask tensor where 1 = inpaint region, 0 = preserve region.
|
||||
masked_image: Original image with masked regions zeroed out.
|
||||
batch_size: Number of images to generate.
|
||||
height: Output image height.
|
||||
width: Output image width.
|
||||
dtype: Data type for the tensors.
|
||||
device: Device to place tensors on.
|
||||
generator: Random generator for reproducibility.
|
||||
|
||||
Returns:
|
||||
Tuple of (mask, masked_image_latents) prepared for the denoising loop.
|
||||
"""
|
||||
# Calculate latent dimensions
|
||||
latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
# Resize mask to latent dimensions
|
||||
mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width), mode="nearest")
|
||||
mask = mask.to(device=device, dtype=dtype)
|
||||
|
||||
# Encode masked image to latents
|
||||
masked_image = masked_image.to(device=device, dtype=dtype)
|
||||
if isinstance(generator, list):
|
||||
masked_image_latents = [
|
||||
retrieve_latents(self.vae.encode(masked_image[i : i + 1]), generator=generator[i])
|
||||
for i in range(masked_image.shape[0])
|
||||
]
|
||||
masked_image_latents = torch.cat(masked_image_latents, dim=0)
|
||||
else:
|
||||
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
|
||||
|
||||
# Apply VAE scaling
|
||||
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# Expand for batch size
|
||||
if mask.shape[0] < batch_size:
|
||||
if not batch_size % mask.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
||||
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
||||
" of masks that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
||||
if masked_image_latents.shape[0] < batch_size:
|
||||
if not batch_size % masked_image_latents.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
|
||||
|
||||
return mask, masked_image_latents
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image,
|
||||
timestep,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
"""Prepare latents for inpainting, returning noise and image_latents for blending.
|
||||
|
||||
Returns:
|
||||
Tuple of (latents, noise, image_latents) where:
|
||||
- latents: Noised image latents for denoising
|
||||
- noise: The noise tensor used for blending
|
||||
- image_latents: Clean image latents for blending
|
||||
"""
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
|
||||
if latents is not None:
|
||||
# Generate noise for blending even if latents are provided
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
# Encode image for blending
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
||||
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
image_latents = torch.cat([image_latents] * (batch_size // image_latents.shape[0]), dim=0)
|
||||
return latents.to(device=device, dtype=dtype), noise, image_latents
|
||||
|
||||
# Encode the input image
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if image.shape[1] != num_channels_latents:
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
||||
|
||||
# Apply scaling (inverse of decoding: decode does latents/scaling_factor + shift_factor)
|
||||
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
else:
|
||||
image_latents = image
|
||||
|
||||
# Handle batch size expansion
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
|
||||
# Generate noise for both initial noising and later blending
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
# Add noise using flow matching scale_noise
|
||||
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
|
||||
|
||||
return latents, noise, image_latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
image,
|
||||
mask_image,
|
||||
strength,
|
||||
height,
|
||||
width,
|
||||
output_type,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if image is None:
|
||||
raise ValueError("`image` input cannot be undefined for inpainting.")
|
||||
|
||||
if mask_image is None:
|
||||
raise ValueError("`mask_image` input cannot be undefined for inpainting.")
|
||||
|
||||
if output_type not in ["latent", "pil", "np", "pt"]:
|
||||
raise ValueError(f"`output_type` must be one of 'latent', 'pil', 'np', or 'pt', but got {output_type}")
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: PipelineImageInput = None,
|
||||
mask_image: PipelineImageInput = None,
|
||||
masked_image_latents: Optional[torch.FloatTensor] = None,
|
||||
strength: float = 1.0,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
cfg_normalization: bool = False,
|
||||
cfg_truncation: float = 1.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for inpainting.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
|
||||
list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
|
||||
a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
|
||||
mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing a mask image for inpainting. White pixels (value 1) in the
|
||||
mask will be inpainted, black pixels (value 0) will be preserved from the original image.
|
||||
masked_image_latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-encoded masked image latents. If provided, the masked image encoding step will be skipped.
|
||||
strength (`float`, *optional*, defaults to 1.0):
|
||||
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
||||
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
||||
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
||||
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
||||
essentially ignores `image` in the masked region.
|
||||
height (`int`, *optional*, defaults to 1024):
|
||||
The height in pixels of the generated image. If not provided, uses the input image height.
|
||||
width (`int`, *optional*, defaults to 1024):
|
||||
The width in pixels of the generated image. If not provided, uses the input image width.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
cfg_normalization (`bool`, *optional*, defaults to False):
|
||||
Whether to apply configuration normalization.
|
||||
cfg_truncation (`float`, *optional*, defaults to 1.0):
|
||||
The truncation value for configuration.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
|
||||
tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
image=image,
|
||||
mask_image=mask_image,
|
||||
strength=strength,
|
||||
height=height,
|
||||
width=width,
|
||||
output_type=output_type,
|
||||
negative_prompt=negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
# 2. Preprocess image and mask
|
||||
init_image = self.image_processor.preprocess(image)
|
||||
init_image = init_image.to(dtype=torch.float32)
|
||||
|
||||
# Get dimensions from the preprocessed image if not specified
|
||||
if height is None:
|
||||
height = init_image.shape[-2]
|
||||
if width is None:
|
||||
width = init_image.shape[-1]
|
||||
|
||||
vae_scale = self.vae_scale_factor * 2
|
||||
if height % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Height must be divisible by {vae_scale} (got {height}). "
|
||||
f"Please adjust the height to a multiple of {vae_scale}."
|
||||
)
|
||||
if width % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Width must be divisible by {vae_scale} (got {width}). "
|
||||
f"Please adjust the width to a multiple of {vae_scale}."
|
||||
)
|
||||
|
||||
# Preprocess mask
|
||||
mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
self._cfg_normalization = cfg_normalization
|
||||
self._cfg_truncation = cfg_truncation
|
||||
|
||||
# 3. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = len(prompt_embeds)
|
||||
|
||||
# If prompt_embeds is provided and prompt is None, skip encoding
|
||||
if prompt_embeds is not None and prompt is None:
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"When `prompt_embeds` is provided without `prompt`, "
|
||||
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
|
||||
)
|
||||
else:
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.in_channels
|
||||
|
||||
# Repeat prompt_embeds for num_images_per_prompt
|
||||
if num_images_per_prompt > 1:
|
||||
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds:
|
||||
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
|
||||
actual_batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
# Calculate latent dimensions for image_seq_len
|
||||
latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
image_seq_len = (latent_height // 2) * (latent_width // 2)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
self.scheduler.sigma_min = 0.0
|
||||
scheduler_kwargs = {"mu": mu}
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
**scheduler_kwargs,
|
||||
)
|
||||
|
||||
# 6. Adjust timesteps based on strength
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
if num_inference_steps < 1:
|
||||
raise ValueError(
|
||||
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
|
||||
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
||||
)
|
||||
latent_timestep = timesteps[:1].repeat(actual_batch_size)
|
||||
|
||||
# 7. Prepare latents from image (returns noise and image_latents for blending)
|
||||
latents, noise, image_latents = self.prepare_latents(
|
||||
init_image,
|
||||
latent_timestep,
|
||||
actual_batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds[0].dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 8. Prepare mask and masked image latents
|
||||
# Create masked image: preserve only unmasked regions (mask=0)
|
||||
if masked_image_latents is None:
|
||||
masked_image = init_image * (mask < 0.5)
|
||||
else:
|
||||
masked_image = None # Will use provided masked_image_latents
|
||||
|
||||
mask, masked_image_latents = self.prepare_mask_latents(
|
||||
mask,
|
||||
masked_image if masked_image is not None else init_image,
|
||||
actual_batch_size,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds[0].dtype,
|
||||
device,
|
||||
generator,
|
||||
)
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 9. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0])
|
||||
timestep = (1000 - timestep) / 1000
|
||||
# Normalized time for time-aware config (0 at start, 1 at end)
|
||||
t_norm = timestep[0].item()
|
||||
|
||||
# Handle cfg truncation
|
||||
current_guidance_scale = self.guidance_scale
|
||||
if (
|
||||
self.do_classifier_free_guidance
|
||||
and self._cfg_truncation is not None
|
||||
and float(self._cfg_truncation) <= 1
|
||||
):
|
||||
if t_norm > self._cfg_truncation:
|
||||
current_guidance_scale = 0.0
|
||||
|
||||
# Run CFG only if configured AND scale is non-zero
|
||||
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
|
||||
|
||||
if apply_cfg:
|
||||
latents_typed = latents.to(self.transformer.dtype)
|
||||
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
|
||||
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
|
||||
timestep_model_input = timestep.repeat(2)
|
||||
else:
|
||||
latent_model_input = latents.to(self.transformer.dtype)
|
||||
prompt_embeds_model_input = prompt_embeds
|
||||
timestep_model_input = timestep
|
||||
|
||||
latent_model_input = latent_model_input.unsqueeze(2)
|
||||
latent_model_input_list = list(latent_model_input.unbind(dim=0))
|
||||
|
||||
model_out_list = self.transformer(
|
||||
latent_model_input_list,
|
||||
timestep_model_input,
|
||||
prompt_embeds_model_input,
|
||||
)[0]
|
||||
|
||||
if apply_cfg:
|
||||
# Perform CFG
|
||||
pos_out = model_out_list[:actual_batch_size]
|
||||
neg_out = model_out_list[actual_batch_size:]
|
||||
|
||||
noise_pred = []
|
||||
for j in range(actual_batch_size):
|
||||
pos = pos_out[j].float()
|
||||
neg = neg_out[j].float()
|
||||
|
||||
pred = pos + current_guidance_scale * (pos - neg)
|
||||
|
||||
# Renormalization
|
||||
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
|
||||
ori_pos_norm = torch.linalg.vector_norm(pos)
|
||||
new_pos_norm = torch.linalg.vector_norm(pred)
|
||||
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
|
||||
if new_pos_norm > max_new_norm:
|
||||
pred = pred * (max_new_norm / new_pos_norm)
|
||||
|
||||
noise_pred.append(pred)
|
||||
|
||||
noise_pred = torch.stack(noise_pred, dim=0)
|
||||
else:
|
||||
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
|
||||
|
||||
noise_pred = noise_pred.squeeze(2)
|
||||
noise_pred = -noise_pred
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
|
||||
assert latents.dtype == torch.float32
|
||||
|
||||
# Inpainting blend: combine denoised latents with original image latents
|
||||
init_latents_proper = image_latents
|
||||
|
||||
# Re-scale original latents to current noise level for proper blending
|
||||
if i < len(timesteps) - 1:
|
||||
noise_timestep = timesteps[i + 1]
|
||||
init_latents_proper = self.scheduler.scale_noise(
|
||||
init_latents_proper, torch.tensor([noise_timestep]), noise
|
||||
)
|
||||
|
||||
# Blend: mask=1 for inpaint region (use denoised), mask=0 for preserve region (use original)
|
||||
latents = (1 - mask) * init_latents_proper + mask * latents
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
mask = callback_outputs.pop("mask", mask)
|
||||
masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ZImagePipelineOutput(images=image)
|
||||
@@ -79,7 +79,8 @@ MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
|
||||
def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
|
||||
# there is no need to call any kernel for fp16/bf16
|
||||
if qweight_type in UNQUANTIZED_TYPES:
|
||||
return x @ qweight.T
|
||||
weight = dequantize_gguf_tensor(qweight)
|
||||
return x @ weight.T
|
||||
|
||||
# TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
|
||||
# contiguous batching and inefficient with diffusers' batching,
|
||||
|
||||
@@ -545,7 +545,9 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -867,7 +867,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -245,13 +245,26 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
if (
|
||||
sum(
|
||||
[
|
||||
self.config.use_beta_sigmas,
|
||||
self.config.use_exponential_sigmas,
|
||||
self.config.use_karras_sigmas,
|
||||
]
|
||||
)
|
||||
> 1
|
||||
):
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
||||
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
|
||||
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
|
||||
deprecate(
|
||||
"algorithm_types dpmsolver and sde-dpmsolver",
|
||||
"1.0.0",
|
||||
deprecation_message,
|
||||
)
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
@@ -259,7 +272,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
self.betas = (
|
||||
torch.linspace(
|
||||
beta_start**0.5,
|
||||
beta_end**0.5,
|
||||
num_train_timesteps,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
** 2
|
||||
)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
@@ -287,7 +308,12 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
# settings for DPM-Solver
|
||||
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
|
||||
if algorithm_type not in [
|
||||
"dpmsolver",
|
||||
"dpmsolver++",
|
||||
"sde-dpmsolver",
|
||||
"sde-dpmsolver++",
|
||||
]:
|
||||
if algorithm_type == "deis":
|
||||
self.register_to_config(algorithm_type="dpmsolver++")
|
||||
else:
|
||||
@@ -724,7 +750,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -738,7 +764,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
Returns:
|
||||
@@ -822,7 +848,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -832,8 +858,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
noise (`torch.Tensor`, *optional*):
|
||||
The noise tensor.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
@@ -860,7 +888,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
sigma_t, sigma_s = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
)
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
@@ -891,7 +922,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output_list: List[torch.Tensor],
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -901,7 +932,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.Tensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
Returns:
|
||||
@@ -1014,7 +1045,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output_list: List[torch.Tensor],
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -1024,8 +1055,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.Tensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
A current instance of a sample created by diffusion process.
|
||||
noise (`torch.Tensor`, *optional*):
|
||||
The noise tensor.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
@@ -1106,7 +1139,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return x_t
|
||||
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
@@ -1216,7 +1251,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
sample = sample.to(torch.float32)
|
||||
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
|
||||
noise = randn_tensor(
|
||||
model_output.shape, generator=generator, device=model_output.device, dtype=torch.float32
|
||||
model_output.shape,
|
||||
generator=generator,
|
||||
device=model_output.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
|
||||
noise = variance_noise.to(device=model_output.device, dtype=torch.float32)
|
||||
|
||||
@@ -141,6 +141,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
|
||||
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
|
||||
use_flow_sigmas (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
|
||||
flow_shift (`float`, *optional*, defaults to 1.0):
|
||||
The flow shift factor. Valid only when `use_flow_sigmas=True`.
|
||||
lambda_min_clipped (`float`, defaults to `-inf`):
|
||||
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
||||
cosine (`squaredcos_cap_v2`) noise schedule.
|
||||
@@ -163,15 +167,15 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
solver_order: int = 2,
|
||||
prediction_type: str = "epsilon",
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++",
|
||||
solver_type: Literal["midpoint", "heun"] = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
euler_at_final: bool = False,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
@@ -180,19 +184,32 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_flow_sigmas: Optional[bool] = False,
|
||||
flow_shift: Optional[float] = 1.0,
|
||||
lambda_min_clipped: float = -float("inf"),
|
||||
variance_type: Optional[str] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
variance_type: Optional[Literal["learned", "learned_range"]] = None,
|
||||
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
|
||||
if (
|
||||
sum(
|
||||
[
|
||||
self.config.use_beta_sigmas,
|
||||
self.config.use_exponential_sigmas,
|
||||
self.config.use_karras_sigmas,
|
||||
]
|
||||
)
|
||||
> 1
|
||||
):
|
||||
raise ValueError(
|
||||
"Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
|
||||
)
|
||||
if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
|
||||
deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
|
||||
deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
|
||||
deprecate(
|
||||
"algorithm_types dpmsolver and sde-dpmsolver",
|
||||
"1.0.0",
|
||||
deprecation_message,
|
||||
)
|
||||
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
@@ -200,7 +217,15 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
|
||||
elif beta_schedule == "scaled_linear":
|
||||
# this schedule is very specific to the latent diffusion model.
|
||||
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
||||
self.betas = (
|
||||
torch.linspace(
|
||||
beta_start**0.5,
|
||||
beta_end**0.5,
|
||||
num_train_timesteps,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
** 2
|
||||
)
|
||||
elif beta_schedule == "squaredcos_cap_v2":
|
||||
# Glide cosine schedule
|
||||
self.betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
@@ -219,7 +244,12 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.init_noise_sigma = 1.0
|
||||
|
||||
# settings for DPM-Solver
|
||||
if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]:
|
||||
if algorithm_type not in [
|
||||
"dpmsolver",
|
||||
"dpmsolver++",
|
||||
"sde-dpmsolver",
|
||||
"sde-dpmsolver++",
|
||||
]:
|
||||
if algorithm_type == "deis":
|
||||
self.register_to_config(algorithm_type="dpmsolver++")
|
||||
else:
|
||||
@@ -250,7 +280,11 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
@@ -382,7 +416,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
|
||||
def _sigma_to_t(self, sigma, log_sigmas):
|
||||
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Convert sigma values to corresponding timestep values through interpolation.
|
||||
|
||||
@@ -419,7 +453,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
return t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
|
||||
def _sigma_to_alpha_sigma_t(self, sigma):
|
||||
def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Convert sigma values to alpha_t and sigma_t values.
|
||||
|
||||
@@ -441,7 +475,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
return alpha_t, sigma_t
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
|
||||
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
|
||||
def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
|
||||
"""
|
||||
Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
|
||||
Models](https://huggingface.co/papers/2206.00364).
|
||||
@@ -567,7 +601,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -581,7 +615,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
Returns:
|
||||
@@ -666,7 +700,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -676,8 +710,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
noise (`torch.Tensor`, *optional*):
|
||||
The noise tensor.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
@@ -704,7 +740,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
|
||||
)
|
||||
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
sigma_t, sigma_s = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
)
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
@@ -736,7 +775,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output_list: List[torch.Tensor],
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -746,7 +785,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.Tensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
|
||||
Returns:
|
||||
@@ -860,7 +899,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
model_output_list: List[torch.Tensor],
|
||||
*args,
|
||||
sample: torch.Tensor = None,
|
||||
sample: Optional[torch.Tensor] = None,
|
||||
noise: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
@@ -870,8 +909,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output_list (`List[torch.Tensor]`):
|
||||
The direct outputs from learned diffusion model at current and latter timesteps.
|
||||
sample (`torch.Tensor`):
|
||||
sample (`torch.Tensor`, *optional*):
|
||||
A current instance of a sample created by diffusion process.
|
||||
noise (`torch.Tensor`, *optional*):
|
||||
The noise tensor.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
@@ -951,7 +992,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
return x_t
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
def _init_step_index(self, timestep: Union[int, torch.Tensor]):
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
|
||||
@@ -975,7 +1016,7 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output: torch.Tensor,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
sample: torch.Tensor,
|
||||
generator=None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
variance_noise: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[SchedulerOutput, Tuple]:
|
||||
@@ -1027,7 +1068,10 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"] and variance_noise is None:
|
||||
noise = randn_tensor(
|
||||
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
|
||||
model_output.shape,
|
||||
generator=generator,
|
||||
device=model_output.device,
|
||||
dtype=model_output.dtype,
|
||||
)
|
||||
elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]:
|
||||
noise = variance_noise
|
||||
@@ -1074,6 +1118,21 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise: torch.Tensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Add noise to the clean `original_samples` using the scheduler's equivalent function.
|
||||
|
||||
Args:
|
||||
original_samples (`torch.Tensor`):
|
||||
The original samples to add noise to.
|
||||
noise (`torch.Tensor`):
|
||||
The noise tensor.
|
||||
timesteps (`torch.IntTensor`):
|
||||
The timesteps at which to add noise.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The noisy samples.
|
||||
"""
|
||||
# Make sure sigmas and timesteps have the same device and dtype as original_samples
|
||||
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
|
||||
@@ -1103,5 +1162,5 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
|
||||
noisy_samples = alpha_t * original_samples + sigma_t * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -1120,7 +1120,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -662,7 +662,9 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -1122,7 +1122,9 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -1083,7 +1083,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
|
||||
def index_for_timestep(
|
||||
self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
|
||||
self,
|
||||
timestep: Union[int, torch.Tensor],
|
||||
schedule_timesteps: Optional[torch.Tensor] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Find the index for a given timestep in the schedule.
|
||||
|
||||
@@ -2012,21 +2012,6 @@ class LongCatImagePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTX2ConditionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTX2ImageToVideoPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -4127,6 +4112,21 @@ class ZImageImg2ImgPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ZImageInpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ZImageOmniPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
396
tests/pipelines/z_image/test_z_image_inpaint.py
Normal file
396
tests/pipelines/z_image/test_z_image_inpaint.py
Normal file
@@ -0,0 +1,396 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
ZImageInpaintPipeline,
|
||||
ZImageTransformer2DModel,
|
||||
)
|
||||
from diffusers.utils.testing_utils import floats_tensor
|
||||
|
||||
from ...testing_utils import torch_device
|
||||
from ..pipeline_params import (
|
||||
IMAGE_TO_IMAGE_IMAGE_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
|
||||
TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
|
||||
)
|
||||
from ..test_pipelines_common import PipelineTesterMixin, to_np
|
||||
|
||||
|
||||
# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
|
||||
# Cannot use enable_full_determinism() which sets it to True
|
||||
# Note: Z-Image does not support FP16 inference due to complex64 RoPE embeddings
|
||||
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||||
torch.use_deterministic_algorithms(False)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
if hasattr(torch.backends, "cuda"):
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class ZImageInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = ZImageInpaintPipeline
|
||||
params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
|
||||
image_params = frozenset(["image", "mask_image"])
|
||||
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"strength",
|
||||
"generator",
|
||||
"latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
supports_dduf = False
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
def setUp(self):
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = ZImageTransformer2DModel(
|
||||
all_patch_size=(2,),
|
||||
all_f_patch_size=(1,),
|
||||
in_channels=16,
|
||||
dim=32,
|
||||
n_layers=2,
|
||||
n_refiner_layers=1,
|
||||
n_heads=2,
|
||||
n_kv_heads=2,
|
||||
norm_eps=1e-5,
|
||||
qk_norm=True,
|
||||
cap_feat_dim=16,
|
||||
rope_theta=256.0,
|
||||
t_scale=1000.0,
|
||||
axes_dims=[8, 4, 4],
|
||||
axes_lens=[256, 32, 32],
|
||||
)
|
||||
# `x_pad_token` and `cap_pad_token` are initialized with `torch.empty` which contains
|
||||
# uninitialized memory. Set them to known values for deterministic test behavior.
|
||||
with torch.no_grad():
|
||||
transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
|
||||
transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
block_out_channels=[32, 64],
|
||||
layers_per_block=1,
|
||||
latent_channels=16,
|
||||
norm_num_groups=32,
|
||||
sample_size=32,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
torch.manual_seed(0)
|
||||
config = Qwen3Config(
|
||||
hidden_size=16,
|
||||
intermediate_size=16,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
vocab_size=151936,
|
||||
max_position_embeddings=512,
|
||||
)
|
||||
text_encoder = Qwen3Model(config)
|
||||
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
import random
|
||||
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
|
||||
# Create mask: 1 = inpaint region, 0 = preserve region
|
||||
mask_image = torch.zeros((1, 1, 32, 32), device=device)
|
||||
mask_image[:, :, 8:24, 8:24] = 1.0 # Inpaint center region
|
||||
|
||||
inputs = {
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "bad quality",
|
||||
"image": image,
|
||||
"mask_image": mask_image,
|
||||
"strength": 1.0,
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 3.0,
|
||||
"cfg_normalization": False,
|
||||
"cfg_truncation": 1.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "np",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
generated_image = image[0]
|
||||
self.assertEqual(generated_image.shape, (32, 32, 3))
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
|
||||
|
||||
def test_num_images_per_prompt(self):
|
||||
import inspect
|
||||
|
||||
sig = inspect.signature(self.pipeline_class.__call__)
|
||||
|
||||
if "num_images_per_prompt" not in sig.parameters:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
batch_sizes = [1, 2]
|
||||
num_images_per_prompts = [1, 2]
|
||||
|
||||
for batch_size in batch_sizes:
|
||||
for num_images_per_prompt in num_images_per_prompts:
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
for key in inputs.keys():
|
||||
if key in self.batch_params:
|
||||
inputs[key] = batch_size * [inputs[key]]
|
||||
|
||||
images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]
|
||||
|
||||
assert images.shape[0] == batch_size * num_images_per_prompt
|
||||
|
||||
del pipe
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def test_attention_slicing_forward_pass(
|
||||
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
|
||||
):
|
||||
if not self.test_attention_slicing:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing1 = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=2)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing2 = pipe(**inputs)[0]
|
||||
|
||||
if test_max_difference:
|
||||
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
|
||||
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
|
||||
self.assertLess(
|
||||
max(max_diff1, max_diff2),
|
||||
expected_max_diff,
|
||||
"Attention slicing should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_vae_tiling(self, expected_diff_max: float = 0.7):
|
||||
import random
|
||||
|
||||
generator_device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to("cpu")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Without tiling
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
# Generate a larger image for the input
|
||||
inputs["image"] = floats_tensor((1, 3, 128, 128), rng=random.Random(0)).to("cpu")
|
||||
# Generate a larger mask for the input
|
||||
mask = torch.zeros((1, 1, 128, 128), device="cpu")
|
||||
mask[:, :, 32:96, 32:96] = 1.0
|
||||
inputs["mask_image"] = mask
|
||||
output_without_tiling = pipe(**inputs)[0]
|
||||
|
||||
# With tiling (standard AutoencoderKL doesn't accept parameters)
|
||||
pipe.vae.enable_tiling()
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
inputs["height"] = inputs["width"] = 128
|
||||
inputs["image"] = floats_tensor((1, 3, 128, 128), rng=random.Random(0)).to("cpu")
|
||||
inputs["mask_image"] = mask
|
||||
output_with_tiling = pipe(**inputs)[0]
|
||||
|
||||
self.assertLess(
|
||||
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
|
||||
expected_diff_max,
|
||||
"VAE tiling should not affect the inference results",
|
||||
)
|
||||
|
||||
def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-3):
|
||||
# Z-Image RoPE embeddings (complex64) have slightly higher numerical tolerance
|
||||
# Inpainting mask blending adds additional numerical variance
|
||||
super().test_pipeline_with_accelerator_device_map(expected_max_difference=expected_max_difference)
|
||||
|
||||
def test_group_offloading_inference(self):
|
||||
# Block-level offloading conflicts with RoPE cache. Pipeline-level offloading (tested separately) works fine.
|
||||
self.skipTest("Using test_pipeline_level_group_offloading_inference instead")
|
||||
|
||||
def test_save_load_float16(self, expected_max_diff=1e-2):
|
||||
# Z-Image does not support FP16 due to complex64 RoPE embeddings
|
||||
self.skipTest("Z-Image does not support FP16 inference")
|
||||
|
||||
def test_float16_inference(self, expected_max_diff=5e-2):
|
||||
# Z-Image does not support FP16 due to complex64 RoPE embeddings
|
||||
self.skipTest("Z-Image does not support FP16 inference")
|
||||
|
||||
def test_strength_parameter(self):
|
||||
"""Test that strength parameter affects the output correctly."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Test with different strength values
|
||||
inputs_low_strength = self.get_dummy_inputs(device)
|
||||
inputs_low_strength["strength"] = 0.2
|
||||
|
||||
inputs_high_strength = self.get_dummy_inputs(device)
|
||||
inputs_high_strength["strength"] = 0.8
|
||||
|
||||
# Both should complete without errors
|
||||
output_low = pipe(**inputs_low_strength).images[0]
|
||||
output_high = pipe(**inputs_high_strength).images[0]
|
||||
|
||||
# Outputs should be different (different amount of transformation)
|
||||
self.assertFalse(np.allclose(output_low, output_high, atol=1e-3))
|
||||
|
||||
def test_invalid_strength(self):
|
||||
"""Test that invalid strength values raise appropriate errors."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
|
||||
# Test strength < 0
|
||||
inputs["strength"] = -0.1
|
||||
with self.assertRaises(ValueError):
|
||||
pipe(**inputs)
|
||||
|
||||
# Test strength > 1
|
||||
inputs["strength"] = 1.5
|
||||
with self.assertRaises(ValueError):
|
||||
pipe(**inputs)
|
||||
|
||||
def test_mask_inpainting(self):
|
||||
"""Test that the mask properly controls which regions are inpainted."""
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
# Generate with full mask (inpaint everything)
|
||||
inputs_full = self.get_dummy_inputs(device)
|
||||
inputs_full["mask_image"] = torch.ones((1, 1, 32, 32), device=device)
|
||||
|
||||
# Generate with no mask (preserve everything)
|
||||
inputs_none = self.get_dummy_inputs(device)
|
||||
inputs_none["mask_image"] = torch.zeros((1, 1, 32, 32), device=device)
|
||||
|
||||
# Both should complete without errors
|
||||
output_full = pipe(**inputs_full).images[0]
|
||||
output_none = pipe(**inputs_none).images[0]
|
||||
|
||||
# Outputs should be different (full inpaint vs preserve)
|
||||
self.assertFalse(np.allclose(output_full, output_none, atol=1e-3))
|
||||
Reference in New Issue
Block a user