Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-21 20:04:37 +08:00

Compare commits: edit-pypi-...device-map

8 commits: 6f5eb0a933, 6708f5c76d, be3c2a0667, 8b4722de57, 07ea0786e8, 83ec2fb793, 54fa0745c3, 3d02cd543e
@@ -237,6 +237,8 @@ By selectively loading and unloading the models you need at a given stage and sh
 Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends.
 
+Most attention backends are compatible with context parallelism. Open an [issue](https://github.com/huggingface/diffusers/issues/new) if a backend is not compatible.
+
 ### Ring Attention
 
 Key (K) and value (V) representations communicate between devices using [Ring Attention](https://huggingface.co/papers/2310.01889). This ensures each split sees every other token's K/V. Each GPU computes attention for its local K/V and passes it to the next GPU in the ring. No single GPU holds the full sequence, which reduces communication latency.
@@ -245,40 +247,60 @@ Pass a [`ContextParallelConfig`] to the `parallel_config` argument of the transf
 ```py
 import torch
-from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig
-
-try:
-    torch.distributed.init_process_group("nccl")
-
-    rank = torch.distributed.get_rank()
-    device = torch.device("cuda", rank % torch.cuda.device_count())
-    torch.cuda.set_device(device)
-
-    transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2))
-    pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda")
-    pipeline.transformer.set_attention_backend("flash")
+from torch import distributed as dist
+
+from diffusers import DiffusionPipeline, ContextParallelConfig
+
+def setup_distributed():
+    if not dist.is_initialized():
+        dist.init_process_group(backend="nccl")
+    rank = dist.get_rank()
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    return device
+
+def main():
+    device = setup_distributed()
+    world_size = dist.get_world_size()
+
+    pipeline = DiffusionPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
+    )
+    pipeline.transformer.set_attention_backend("_native_cudnn")
+
+    cp_config = ContextParallelConfig(ring_degree=world_size)
+    pipeline.transformer.enable_parallelism(config=cp_config)
 
     prompt = """
     cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
     highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
     """
 
     # Must specify generator so all ranks start with same latents (or pass your own)
     generator = torch.Generator().manual_seed(42)
-    image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0]
-
-    if rank == 0:
-        image.save("output.png")
-
-except Exception as e:
-    print(f"An error occurred: {e}")
-    torch.distributed.breakpoint()
-    raise
-
-finally:
-    if torch.distributed.is_initialized():
-        torch.distributed.destroy_process_group()
+    image = pipeline(
+        prompt,
+        guidance_scale=3.5,
+        num_inference_steps=50,
+        generator=generator,
+    ).images[0]
+
+    if dist.get_rank() == 0:
+        image.save(f"output.png")
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+if __name__ == "__main__":
+    main()
 ```
+
+The script above needs to be run with a distributed launcher that is compatible with PyTorch, such as [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html). Set `--nproc-per-node` to the number of GPUs available.
+
+```shell
+torchrun --nproc-per-node 2 above_script.py
+```
 
 ### Ulysses Attention
 
 [Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* communication (every device sends/receives data to every other device). Each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally on all tokens for its head, then performs another all-to-all to regroup results by tokens for the next layer.
@@ -288,5 +310,26 @@ finally:
 Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
 
 ```py
+# Depending on the number of GPUs available.
 pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
+```
+
+### parallel_config
+
+Pass `parallel_config` during model initialization to enable context parallelism.
+
+```py
+CKPT_ID = "black-forest-labs/FLUX.1-dev"
+
+cp_config = ContextParallelConfig(ring_degree=2)
+transformer = AutoModel.from_pretrained(
+    CKPT_ID,
+    subfolder="transformer",
+    torch_dtype=torch.bfloat16,
+    parallel_config=cp_config
+)
+
+pipeline = DiffusionPipeline.from_pretrained(
+    CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
+).to(device)
 ```
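The `parallel_config` snippet above assumes the imports and the `device` handle from the Ring Attention example earlier. Below is a minimal editorial sketch of how the pieces fit together when launched with `torchrun`; the checkpoint ID and degree values are illustrative, not prescriptive.

```py
# Sketch only: combines `parallel_config` at model init with the distributed setup
# from the Ring Attention example above. Assumes 2 GPUs and a launch such as
# `torchrun --nproc-per-node 2 this_script.py`.
import torch
from torch import distributed as dist

from diffusers import AutoModel, DiffusionPipeline, ContextParallelConfig

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)

CKPT_ID = "black-forest-labs/FLUX.1-dev"
transformer = AutoModel.from_pretrained(
    CKPT_ID,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=ContextParallelConfig(ring_degree=dist.get_world_size()),
)
pipeline = DiffusionPipeline.from_pretrained(
    CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16
).to(device)

generator = torch.Generator().manual_seed(42)  # same seed on every rank
image = pipeline("a photo of a cat", num_inference_steps=50, generator=generator).images[0]

if rank == 0:
    image.save("output.png")
dist.destroy_process_group()
```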
@@ -404,6 +404,8 @@ except OptionalDependencyNotAvailable:
 else:
     _import_structure["modular_pipelines"].extend(
         [
+            "Flux2AutoBlocks",
+            "Flux2ModularPipeline",
             "FluxAutoBlocks",
             "FluxKontextAutoBlocks",
             "FluxKontextModularPipeline",
@@ -419,6 +421,8 @@ else:
             "Wan22AutoBlocks",
             "WanAutoBlocks",
             "WanModularPipeline",
+            "ZImageAutoBlocks",
+            "ZImageModularPipeline",
         ]
     )
     _import_structure["pipelines"].extend(
@@ -1109,6 +1113,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
         from .modular_pipelines import (
+            Flux2AutoBlocks,
+            Flux2ModularPipeline,
             FluxAutoBlocks,
             FluxKontextAutoBlocks,
             FluxKontextModularPipeline,
@@ -1124,6 +1130,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             Wan22AutoBlocks,
             WanAutoBlocks,
             WanModularPipeline,
+            ZImageAutoBlocks,
+            ZImageModularPipeline,
         )
         from .pipelines import (
             AllegroPipeline,
@@ -16,7 +16,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
-from torch.nn.functional import fold, unfold
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
@@ -532,7 +531,19 @@ def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
         Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W
         // patch_size)` is the number of patches.
     """
-    return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
+    b, c, h, w = img.shape
+    p = patch_size
+
+    # Reshape to (B, C, H//p, p, W//p, p) separating grid and patch dimensions
+    img = img.reshape(b, c, h // p, p, w // p, p)
+
+    # Permute to (B, H//p, W//p, C, p, p) using einsum
+    # n=batch, c=channels, h=grid_height, p=patch_height, w=grid_width, q=patch_width
+    img = torch.einsum("nchpwq->nhwcpq", img)
+
+    # Flatten to (B, L, C * p * p)
+    img = img.reshape(b, -1, c * p * p)
+    return img
 
 
 def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
@@ -554,12 +565,26 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Te
         Reconstructed image tensor of shape `(B, C, H, W)`.
     """
     if isinstance(shape, tuple):
-        shape = shape[-2:]
+        h, w = shape[-2:]
     elif isinstance(shape, torch.Tensor):
-        shape = (int(shape[0]), int(shape[1]))
+        h, w = (int(shape[0]), int(shape[1]))
     else:
         raise NotImplementedError(f"shape type {type(shape)} not supported")
-    return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
+
+    b, l, d = seq.shape
+    p = patch_size
+    c = d // (p * p)
+
+    # Reshape back to grid structure: (B, H//p, W//p, C, p, p)
+    seq = seq.reshape(b, h // p, w // p, c, p, p)
+
+    # Permute back to image layout: (B, C, H//p, p, W//p, p)
+    # n=batch, h=grid_height, w=grid_width, c=channels, p=patch_height, q=patch_width
+    seq = torch.einsum("nhwcpq->nchpwq", seq)
+
+    # Final reshape to (B, C, H, W)
+    seq = seq.reshape(b, c, h, w)
+    return seq
 
 
 class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
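A quick round trip makes it clear that the new reshape/einsum path above is a lossless inverse pair. This is an editorial sketch, not part of the diff; the import path for `img2seq`/`seq2img` is an assumption and may need adjusting to your checkout.

```py
import torch

# Assumed module path for the PRX helpers shown in the hunk above; adjust if needed.
from diffusers.models.transformers.transformer_prx import img2seq, seq2img

B, C, H, W, p = 2, 4, 8, 8, 2
img = torch.randn(B, C, H, W)

seq = img2seq(img, patch_size=p)                      # (B, (H//p)*(W//p), C*p*p)
assert seq.shape == (B, (H // p) * (W // p), C * p * p)

img_back = seq2img(seq, patch_size=p, shape=(H, W))   # (B, C, H, W)
assert torch.equal(img, img_back)                     # pure reshapes/permutes, exact round trip
```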
@@ -52,6 +52,10 @@ else:
         "FluxKontextAutoBlocks",
         "FluxKontextModularPipeline",
     ]
+    _import_structure["flux2"] = [
+        "Flux2AutoBlocks",
+        "Flux2ModularPipeline",
+    ]
     _import_structure["qwenimage"] = [
         "QwenImageAutoBlocks",
         "QwenImageModularPipeline",
@@ -60,6 +64,10 @@ else:
         "QwenImageEditPlusModularPipeline",
         "QwenImageEditPlusAutoBlocks",
     ]
+    _import_structure["z_image"] = [
+        "ZImageAutoBlocks",
+        "ZImageModularPipeline",
+    ]
     _import_structure["components_manager"] = ["ComponentsManager"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -71,6 +79,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     else:
         from .components_manager import ComponentsManager
         from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
+        from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
         from .modular_pipeline import (
             AutoPipelineBlocks,
             BlockState,
@@ -91,6 +100,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         )
         from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
         from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
+        from .z_image import ZImageAutoBlocks, ZImageModularPipeline
 else:
     import sys
 
src/diffusers/modular_pipelines/flux2/__init__.py (new file, 111 lines)
@@ -0,0 +1,111 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["encoders"] = [
        "Flux2TextEncoderStep",
        "Flux2RemoteTextEncoderStep",
        "Flux2VaeEncoderStep",
    ]
    _import_structure["before_denoise"] = [
        "Flux2SetTimestepsStep",
        "Flux2PrepareLatentsStep",
        "Flux2RoPEInputsStep",
        "Flux2PrepareImageLatentsStep",
    ]
    _import_structure["denoise"] = [
        "Flux2LoopDenoiser",
        "Flux2LoopAfterDenoiser",
        "Flux2DenoiseLoopWrapper",
        "Flux2DenoiseStep",
    ]
    _import_structure["decoders"] = ["Flux2DecodeStep"]
    _import_structure["inputs"] = [
        "Flux2ProcessImagesInputStep",
        "Flux2TextInputStep",
    ]
    _import_structure["modular_blocks"] = [
        "ALL_BLOCKS",
        "AUTO_BLOCKS",
        "REMOTE_AUTO_BLOCKS",
        "TEXT2IMAGE_BLOCKS",
        "IMAGE_CONDITIONED_BLOCKS",
        "Flux2AutoBlocks",
        "Flux2AutoVaeEncoderStep",
        "Flux2BeforeDenoiseStep",
        "Flux2VaeEncoderSequentialStep",
    ]
    _import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .before_denoise import (
            Flux2PrepareImageLatentsStep,
            Flux2PrepareLatentsStep,
            Flux2RoPEInputsStep,
            Flux2SetTimestepsStep,
        )
        from .decoders import Flux2DecodeStep
        from .denoise import (
            Flux2DenoiseLoopWrapper,
            Flux2DenoiseStep,
            Flux2LoopAfterDenoiser,
            Flux2LoopDenoiser,
        )
        from .encoders import (
            Flux2RemoteTextEncoderStep,
            Flux2TextEncoderStep,
            Flux2VaeEncoderStep,
        )
        from .inputs import (
            Flux2ProcessImagesInputStep,
            Flux2TextInputStep,
        )
        from .modular_blocks import (
            ALL_BLOCKS,
            AUTO_BLOCKS,
            IMAGE_CONDITIONED_BLOCKS,
            REMOTE_AUTO_BLOCKS,
            TEXT2IMAGE_BLOCKS,
            Flux2AutoBlocks,
            Flux2AutoVaeEncoderStep,
            Flux2BeforeDenoiseStep,
            Flux2VaeEncoderSequentialStep,
        )
        from .modular_pipeline import Flux2ModularPipeline
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
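As with the other modular pipelines, the `__init__.py` above wires everything through `_LazyModule`, so the heavy submodules are only imported when one of these names is first accessed. A small usage sketch, assuming a diffusers build that includes this file:

```py
# Only the submodules that define these names are resolved, and only on first access.
from diffusers.modular_pipelines.flux2 import Flux2AutoBlocks, Flux2ModularPipeline

print(Flux2AutoBlocks.__name__, Flux2ModularPipeline.__name__)
```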
src/diffusers/modular_pipelines/flux2/before_denoise.py (new file, 508 lines)
@@ -0,0 +1,508 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import List, Optional, Union

import numpy as np
import torch

from ...models import Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Compute empirical mu for Flux2 timestep scheduling."""
    a1, b1 = 8.73809524e-05, 1.89833333
    a2, b2 = 0.00016927, 0.45666666

    if image_seq_len > 4300:
        mu = a2 * image_seq_len + b2
        return float(mu)

    m_200 = a2 * image_seq_len + b2
    m_10 = a1 * image_seq_len + b1

    a = (m_200 - m_10) / 190.0
    b = m_200 - 200.0 * a
    mu = a * num_steps + b

    return float(mu)


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class Flux2SetTimestepsStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
            ComponentSpec("transformer", Flux2Transformer2DModel),
        ]

    @property
    def description(self) -> str:
        return "Step that sets the scheduler's timesteps for Flux2 inference using empirical mu calculation"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_inference_steps", default=50),
            InputParam("timesteps"),
            InputParam("sigmas"),
            InputParam("guidance_scale", default=4.0),
            InputParam("latents", type_hint=torch.Tensor),
            InputParam("num_images_per_prompt", default=1),
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam(
                "batch_size",
                required=True,
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
            OutputParam(
                "num_inference_steps",
                type_hint=int,
                description="The number of denoising steps to perform at inference time",
            ),
            OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
        ]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device

        scheduler = components.scheduler

        height = block_state.height or components.default_height
        width = block_state.width or components.default_width
        vae_scale_factor = components.vae_scale_factor

        latent_height = 2 * (int(height) // (vae_scale_factor * 2))
        latent_width = 2 * (int(width) // (vae_scale_factor * 2))
        image_seq_len = (latent_height // 2) * (latent_width // 2)

        num_inference_steps = block_state.num_inference_steps
        sigmas = block_state.sigmas
        timesteps = block_state.timesteps

        if timesteps is None and sigmas is None:
            sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
            sigmas = None

        mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)

        timesteps, num_inference_steps = retrieve_timesteps(
            scheduler,
            num_inference_steps,
            block_state.device,
            timesteps=timesteps,
            sigmas=sigmas,
            mu=mu,
        )
        block_state.timesteps = timesteps
        block_state.num_inference_steps = num_inference_steps

        batch_size = block_state.batch_size * block_state.num_images_per_prompt
        guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
        guidance = guidance.expand(batch_size)
        block_state.guidance = guidance

        components.scheduler.set_begin_index(0)

        self.set_block_state(state, block_state)
        return components, state


class Flux2PrepareLatentsStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def description(self) -> str:
        return "Prepare latents step that prepares the initial noise latents for Flux2 text-to-image generation"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam("latents", type_hint=Optional[torch.Tensor]),
            InputParam("num_images_per_prompt", type_hint=int, default=1),
            InputParam("generator"),
            InputParam(
                "batch_size",
                required=True,
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
            ),
            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
            ),
            OutputParam("latent_ids", type_hint=torch.Tensor, description="Position IDs for the latents (for RoPE)"),
        ]

    @staticmethod
    def check_inputs(components, block_state):
        vae_scale_factor = components.vae_scale_factor
        if (block_state.height is not None and block_state.height % (vae_scale_factor * 2) != 0) or (
            block_state.width is not None and block_state.width % (vae_scale_factor * 2) != 0
        ):
            logger.warning(
                f"`height` and `width` have to be divisible by {vae_scale_factor * 2} but are {block_state.height} and {block_state.width}."
            )

    @staticmethod
    def _prepare_latent_ids(latents: torch.Tensor):
        """
        Generates 4D position coordinates (T, H, W, L) for latent tensors.

        Args:
            latents: Latent tensor of shape (B, C, H, W)

        Returns:
            Position IDs tensor of shape (B, H*W, 4)
        """
        batch_size, _, height, width = latents.shape

        t = torch.arange(1)
        h = torch.arange(height)
        w = torch.arange(width)
        l = torch.arange(1)

        latent_ids = torch.cartesian_prod(t, h, w, l)
        latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)

        return latent_ids

    @staticmethod
    def _pack_latents(latents):
        """Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)"""
        batch_size, num_channels, height, width = latents.shape
        latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
        return latents

    @staticmethod
    def prepare_latents(
        comp,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        height = 2 * (int(height) // (comp.vae_scale_factor * 2))
        width = 2 * (int(width) // (comp.vae_scale_factor * 2))

        shape = (batch_size, num_channels_latents * 4, height // 2, width // 2)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device=device, dtype=dtype)

        return latents

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.height = block_state.height or components.default_height
        block_state.width = block_state.width or components.default_width
        block_state.device = components._execution_device
        block_state.num_channels_latents = components.num_channels_latents

        self.check_inputs(components, block_state)
        batch_size = block_state.batch_size * block_state.num_images_per_prompt

        latents = self.prepare_latents(
            components,
            batch_size,
            block_state.num_channels_latents,
            block_state.height,
            block_state.width,
            block_state.dtype,
            block_state.device,
            block_state.generator,
            block_state.latents,
        )

        latent_ids = self._prepare_latent_ids(latents)
        latent_ids = latent_ids.to(block_state.device)

        latents = self._pack_latents(latents)

        block_state.latents = latents
        block_state.latent_ids = latent_ids

        self.set_block_state(state, block_state)
        return components, state


class Flux2RoPEInputsStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return "Step that prepares the 4D RoPE position IDs for Flux2 denoising. Should be placed after text encoder and latent preparation steps."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(name="prompt_embeds", required=True),
            InputParam(name="latent_ids"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                name="txt_ids",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
            ),
            OutputParam(
                name="latent_ids",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
            ),
        ]

    @staticmethod
    def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
        """Prepare 4D position IDs for text tokens."""
        B, L, _ = x.shape
        out_ids = []

        for i in range(B):
            t = torch.arange(1) if t_coord is None else t_coord[i]
            h = torch.arange(1)
            w = torch.arange(1)
            seq_l = torch.arange(L)

            coords = torch.cartesian_prod(t, h, w, seq_l)
            out_ids.append(coords)

        return torch.stack(out_ids)

    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        prompt_embeds = block_state.prompt_embeds
        device = prompt_embeds.device

        block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
        block_state.txt_ids = block_state.txt_ids.to(device)

        self.set_block_state(state, block_state)
        return components, state


class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return "Step that prepares image latents and their position IDs for Flux2 image conditioning."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image_latents", type_hint=List[torch.Tensor]),
            InputParam("batch_size", required=True, type_hint=int),
            InputParam("num_images_per_prompt", default=1, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="Packed image latents for conditioning",
            ),
            OutputParam(
                "image_latent_ids",
                type_hint=torch.Tensor,
                description="Position IDs for image latents",
            ),
        ]

    @staticmethod
    def _prepare_image_ids(image_latents: List[torch.Tensor], scale: int = 10):
        """
        Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.

        Args:
            image_latents: A list of image latent feature tensors of shape (1, C, H, W).
            scale: Factor used to define the time separation between latents.

        Returns:
            Combined coordinate tensor of shape (1, N_total, 4)
        """
        if not isinstance(image_latents, list):
            raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")

        t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
        t_coords = [t.view(-1) for t in t_coords]

        image_latent_ids = []
        for x, t in zip(image_latents, t_coords):
            x = x.squeeze(0)
            _, height, width = x.shape

            x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
            image_latent_ids.append(x_ids)

        image_latent_ids = torch.cat(image_latent_ids, dim=0)
        image_latent_ids = image_latent_ids.unsqueeze(0)

        return image_latent_ids

    @staticmethod
    def _pack_latents(latents):
        """Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)"""
        batch_size, num_channels, height, width = latents.shape
        latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
        return latents

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        image_latents = block_state.image_latents

        if image_latents is None:
            block_state.image_latents = None
            block_state.image_latent_ids = None
            self.set_block_state(state, block_state)

            return components, state

        device = components._execution_device
        batch_size = block_state.batch_size * block_state.num_images_per_prompt

        image_latent_ids = self._prepare_image_ids(image_latents)

        packed_latents = []
        for latent in image_latents:
            packed = self._pack_latents(latent)
            packed = packed.squeeze(0)
            packed_latents.append(packed)

        image_latents = torch.cat(packed_latents, dim=0)
        image_latents = image_latents.unsqueeze(0)

        image_latents = image_latents.repeat(batch_size, 1, 1)
        image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
        image_latent_ids = image_latent_ids.to(device)

        block_state.image_latents = image_latents
        block_state.image_latent_ids = image_latent_ids

        self.set_block_state(state, block_state)
        return components, state
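To make the mu schedule and the latent bookkeeping above concrete, here is a hedged editorial sketch that exercises `compute_empirical_mu`, `_prepare_latent_ids`, and `_pack_latents` on dummy tensors. Shapes and values are illustrative only; the import path comes from this diff and requires a build that includes the new file.

```py
import torch

from diffusers.modular_pipelines.flux2.before_denoise import (
    Flux2PrepareLatentsStep,
    compute_empirical_mu,
)

# For image_seq_len <= 4300, mu is a linear interpolation in num_steps between a
# 10-step fit (a1, b1) and a 200-step fit (a2, b2); e.g. seq_len 4096 at 50 steps gives ~2.02.
print(compute_empirical_mu(image_seq_len=4096, num_steps=50))

latents = torch.randn(2, 128, 32, 32)  # (B, C, H_lat, W_lat), dummy values

# (B, H_lat * W_lat, 4) position IDs: T and L are always 0, H/W index the latent grid.
latent_ids = Flux2PrepareLatentsStep._prepare_latent_ids(latents)
assert latent_ids.shape == (2, 32 * 32, 4)

# (B, C, H, W) -> (B, H * W, C): the token sequence fed to the transformer.
packed = Flux2PrepareLatentsStep._pack_latents(latents)
assert packed.shape == (2, 32 * 32, 128)
```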
src/diffusers/modular_pipelines/flux2/decoders.py (new file, 146 lines)
@@ -0,0 +1,146 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Tuple, Union

import numpy as np
import PIL
import torch

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKLFlux2
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class Flux2DecodeStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLFlux2),
            ComponentSpec(
                "image_processor",
                Flux2ImageProcessor,
                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("output_type", default="pil"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The denoised latents from the denoising step",
            ),
            InputParam(
                "latent_ids",
                required=True,
                type_hint=torch.Tensor,
                description="Position IDs for the latents, used for unpacking",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[str]:
        return [
            OutputParam(
                "images",
                type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
                description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
            )
        ]

    @staticmethod
    def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor:
        """
        Unpack latents using position IDs to scatter tokens into place.

        Args:
            x: Packed latents tensor of shape (B, seq_len, C)
            x_ids: Position IDs tensor of shape (B, seq_len, 4) with (T, H, W, L) coordinates

        Returns:
            Unpacked latents tensor of shape (B, C, H, W)
        """
        x_list = []
        for data, pos in zip(x, x_ids):
            _, ch = data.shape  # noqa: F841
            h_ids = pos[:, 1].to(torch.int64)
            w_ids = pos[:, 2].to(torch.int64)

            h = torch.max(h_ids) + 1
            w = torch.max(w_ids) + 1

            flat_ids = h_ids * w + w_ids

            out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
            out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)

            out = out.view(h, w, ch).permute(2, 0, 1)
            x_list.append(out)

        return torch.stack(x_list, dim=0)

    @staticmethod
    def _unpatchify_latents(latents):
        """Convert patchified latents back to regular format."""
        batch_size, num_channels_latents, height, width = latents.shape
        latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
        latents = latents.permute(0, 1, 4, 2, 5, 3)
        latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
        return latents

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        vae = components.vae

        if block_state.output_type == "latent":
            block_state.images = block_state.latents
        else:
            latents = block_state.latents
            latent_ids = block_state.latent_ids

            latents = self._unpack_latents_with_ids(latents, latent_ids)

            latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
            latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
                latents.device, latents.dtype
            )
            latents = latents * latents_bn_std + latents_bn_mean

            latents = self._unpatchify_latents(latents)

            block_state.images = vae.decode(latents, return_dict=False)[0]
            block_state.images = components.image_processor.postprocess(
                block_state.images, output_type=block_state.output_type
            )

        self.set_block_state(state, block_state)
        return components, state
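The scatter-based unpacking above is the inverse of the packing done in `before_denoise.py`. A hedged round-trip sketch with dummy latents (module paths come from this diff and the shape is illustrative):

```py
import torch

from diffusers.modular_pipelines.flux2.before_denoise import Flux2PrepareLatentsStep
from diffusers.modular_pipelines.flux2.decoders import Flux2DecodeStep

latents = torch.randn(1, 32, 16, 16)  # (B, C, H, W), illustrative shape

ids = Flux2PrepareLatentsStep._prepare_latent_ids(latents)  # (B, H*W, 4)
packed = Flux2PrepareLatentsStep._pack_latents(latents)     # (B, H*W, C)

# Each token is scattered back to its (H, W) position using the stored IDs.
unpacked = Flux2DecodeStep._unpack_latents_with_ids(packed, ids)
assert torch.equal(unpacked, latents)
```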
src/diffusers/modular_pipelines/flux2/denoise.py (new file, 252 lines)
@@ -0,0 +1,252 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Tuple

import torch

from ...models import Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging
from ..modular_pipeline import (
    BlockState,
    LoopSequentialPipelineBlocks,
    ModularPipelineBlocks,
    PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class Flux2LoopDenoiser(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("transformer", Flux2Transformer2DModel)]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoises the latents for Flux2. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `Flux2DenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("joint_attention_kwargs"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The latents to denoise. Shape: (B, seq_len, C)",
            ),
            InputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
            ),
            InputParam(
                "image_latent_ids",
                type_hint=torch.Tensor,
                description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
            ),
            InputParam(
                "guidance",
                required=True,
                type_hint=torch.Tensor,
                description="Guidance scale as a tensor",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Text embeddings from Mistral3",
            ),
            InputParam(
                "txt_ids",
                required=True,
                type_hint=torch.Tensor,
                description="4D position IDs for text tokens (T, H, W, L)",
            ),
            InputParam(
                "latent_ids",
                required=True,
                type_hint=torch.Tensor,
                description="4D position IDs for latent tokens (T, H, W, L)",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> PipelineState:
        latents = block_state.latents
        latent_model_input = latents.to(components.transformer.dtype)
        img_ids = block_state.latent_ids

        image_latents = getattr(block_state, "image_latents", None)
        if image_latents is not None:
            latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
            image_latent_ids = block_state.image_latent_ids
            img_ids = torch.cat([img_ids, image_latent_ids], dim=1)

        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        noise_pred = components.transformer(
            hidden_states=latent_model_input,
            timestep=timestep / 1000,
            guidance=block_state.guidance,
            encoder_hidden_states=block_state.prompt_embeds,
            txt_ids=block_state.txt_ids,
            img_ids=img_ids,
            joint_attention_kwargs=block_state.joint_attention_kwargs,
            return_dict=False,
        )[0]

        noise_pred = noise_pred[:, : latents.size(1)]
        block_state.noise_pred = noise_pred

        return components, block_state


class Flux2LoopAfterDenoiser(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that updates the latents after denoising. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `Flux2DenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return []

    @property
    def intermediate_inputs(self) -> List[str]:
        return [InputParam("generator")]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        latents_dtype = block_state.latents.dtype
        block_state.latents = components.scheduler.step(
            block_state.noise_pred,
            t,
            block_state.latents,
            return_dict=False,
        )[0]

        if block_state.latents.dtype != latents_dtype:
            if torch.backends.mps.is_available():
                block_state.latents = block_state.latents.to(latents_dtype)

        return components, block_state


class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return (
            "Pipeline block that iteratively denoises the latents over `timesteps`. "
            "The specific steps within each iteration can be customized with `sub_blocks` attribute"
        )

    @property
    def loop_expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
            ComponentSpec("transformer", Flux2Transformer2DModel),
        ]

    @property
    def loop_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process.",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.num_warmup_steps = max(
            len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
        )

        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
            for i, t in enumerate(block_state.timesteps):
                components, block_state = self.loop_step(components, block_state, i=i, t=t)

                if i == len(block_state.timesteps) - 1 or (
                    (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
                ):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self.set_block_state(state, block_state)
        return components, state


class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
    block_classes = [Flux2LoopDenoiser, Flux2LoopAfterDenoiser]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents for Flux2. \n"
            "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `Flux2LoopDenoiser`\n"
            " - `Flux2LoopAfterDenoiser`\n"
            "This block supports both text-to-image and image-conditioned generation."
        )
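When image conditioning is active, `Flux2LoopDenoiser` concatenates the noise latents and the image latents along the sequence dimension, runs the transformer on the combined sequence, and keeps only the predictions for the noise tokens. A shape-only editorial sketch of that slicing with dummy tensors (no model involved, sizes are illustrative):

```py
import torch

B, noise_len, img_len, C = 2, 1024, 512, 128  # illustrative sizes

latents = torch.randn(B, noise_len, C)
image_latents = torch.randn(B, img_len, C)

# Combined sequence fed to the transformer, as in Flux2LoopDenoiser.__call__
latent_model_input = torch.cat([latents, image_latents], dim=1)
assert latent_model_input.shape == (B, noise_len + img_len, C)

# Stand-in for the transformer output, which has the same sequence length.
noise_pred = torch.randn_like(latent_model_input)

# Only the first `noise_len` tokens go to the scheduler step.
noise_pred = noise_pred[:, : latents.size(1)]
assert noise_pred.shape == (B, noise_len, C)
```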
src/diffusers/modular_pipelines/flux2/encoders.py (new file, 349 lines)
@@ -0,0 +1,349 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union

import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration

from ...models import AutoencoderKLFlux2
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def format_text_input(prompts: List[str], system_message: str = None):
    """Format prompts for Mistral3 chat template."""
    cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]

    return [
        [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {"role": "user", "content": [{"type": "text", "text": prompt}]},
        ]
        for prompt in cleaned_txt
    ]


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


class Flux2TextEncoderStep(ModularPipelineBlocks):
    model_name = "flux2"

    # fmt: off
    DEFAULT_SYSTEM_MESSAGE = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
    # fmt: on

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text embeddings using Mistral3 to guide the image generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", Mistral3ForConditionalGeneration),
            ComponentSpec("tokenizer", AutoProcessor),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
            InputParam("max_sequence_length", type_hint=int, default=512, required=False),
            InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
            InputParam("joint_attention_kwargs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Text embeddings from Mistral3 used to guide the image generation",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        prompt = block_state.prompt
        prompt_embeds = getattr(block_state, "prompt_embeds", None)

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
                "Please make sure to only forward one of the two."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    @staticmethod
    def _get_mistral_3_prompt_embeds(
        text_encoder: Mistral3ForConditionalGeneration,
        tokenizer: AutoProcessor,
        prompt: Union[str, List[str]],
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        max_sequence_length: int = 512,
        # fmt: off
        system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
        # fmt: on
        hidden_states_layers: Tuple[int] = (10, 20, 30),
    ):
        dtype = text_encoder.dtype if dtype is None else dtype
        device = text_encoder.device if device is None else device

        prompt = [prompt] if isinstance(prompt, str) else prompt

        messages_batch = format_text_input(prompts=prompt, system_message=system_message)

        inputs = tokenizer.apply_chat_template(
            messages_batch,
            add_generation_prompt=False,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_sequence_length,
        )

        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        output = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

        out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
        out = out.to(dtype=dtype, device=device)

        batch_size, num_channels, seq_len, hidden_dim = out.shape
        prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)

        return prompt_embeds

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.device = components._execution_device

        if block_state.prompt_embeds is not None:
            self.set_block_state(state, block_state)
            return components, state

        prompt = block_state.prompt
        if prompt is None:
            prompt = ""
        prompt = [prompt] if isinstance(prompt, str) else prompt

        block_state.prompt_embeds = self._get_mistral_3_prompt_embeds(
            text_encoder=components.text_encoder,
            tokenizer=components.tokenizer,
            prompt=prompt,
            device=block_state.device,
            max_sequence_length=block_state.max_sequence_length,
            system_message=self.DEFAULT_SYSTEM_MESSAGE,
            hidden_states_layers=block_state.text_encoder_out_layers,
        )

        self.set_block_state(state, block_state)
        return components, state


class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
    model_name = "flux2"

    REMOTE_URL = "https://remote-text-encoder-flux-2.huggingface.co/predict"

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text embeddings using a remote API endpoint"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Text embeddings from remote API used to guide the image generation",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        prompt = block_state.prompt
        prompt_embeds = getattr(block_state, "prompt_embeds", None)

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
                "Please make sure to only forward one of the two."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        import io

        import requests
        from huggingface_hub import get_token

        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.device = components._execution_device

        if block_state.prompt_embeds is not None:
            self.set_block_state(state, block_state)
            return components, state

        prompt = block_state.prompt
        if prompt is None:
            prompt = ""
        prompt = [prompt] if isinstance(prompt, str) else prompt

        response = requests.post(
            self.REMOTE_URL,
            json={"prompt": prompt},
            headers={
                "Authorization": f"Bearer {get_token()}",
                "Content-Type": "application/json",
            },
        )
        response.raise_for_status()

        block_state.prompt_embeds = torch.load(io.BytesIO(response.content), weights_only=True)
        block_state.prompt_embeds = block_state.prompt_embeds.to(block_state.device)

        self.set_block_state(state, block_state)
        return components, state


class Flux2VaeEncoderStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return "VAE Encoder step that encodes preprocessed images into latent representations for Flux2."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("vae", AutoencoderKLFlux2)]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("condition_images", type_hint=List[torch.Tensor]),
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_latents",
                type_hint=List[torch.Tensor],
                description="List of latent representations for each reference image",
            ),
        ]

    @staticmethod
    def _patchify_latents(latents):
        """Convert latents to patchified format for Flux2."""
        batch_size, num_channels_latents, height, width = latents.shape
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 1, 3, 5, 2, 4)
        latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
        return latents

    def _encode_vae_image(self, vae: AutoencoderKLFlux2, image: torch.Tensor, generator: torch.Generator):
        """Encode a single image using Flux2 VAE with batch norm normalization."""
        if image.ndim != 4:
            raise ValueError(f"Expected image dims 4, got {image.ndim}.")

        image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode="argmax")
        image_latents = self._patchify_latents(image_latents)

        latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
        latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps)
        latents_bn_std = latents_bn_std.to(image_latents.device, image_latents.dtype)
        image_latents = (image_latents - latents_bn_mean) / latents_bn_std

        return image_latents

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        condition_images = block_state.condition_images

        if condition_images is None:
            return components, state

        device = components._execution_device
        dtype = components.vae.dtype

        image_latents = []
        for image in condition_images:
            image = image.to(device=device, dtype=dtype)
            latent = self._encode_vae_image(
                vae=components.vae,
                image=image,
                generator=block_state.generator,
            )
            image_latents.append(latent)

        block_state.image_latents = image_latents

        self.set_block_state(state, block_state)
        return components, state
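`_get_mistral_3_prompt_embeds` builds the prompt embedding by stacking a few decoder hidden states and folding the layer axis into the feature axis. A minimal sketch with made-up tensor sizes (dummy tensors, not real Mistral3 outputs) shows the same reshape:

```py
import torch

batch_size, seq_len, hidden_dim = 2, 512, 64
# stand-ins for `output.hidden_states` (layer count and sizes are made up)
hidden_states = [torch.randn(batch_size, seq_len, hidden_dim) for _ in range(31)]

hidden_states_layers = (10, 20, 30)
out = torch.stack([hidden_states[k] for k in hidden_states_layers], dim=1)             # (B, 3, S, D)
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, 3 * hidden_dim)   # (B, S, 3*D)
print(prompt_embeds.shape)  # torch.Size([2, 512, 192])
```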
src/diffusers/modular_pipelines/flux2/inputs.py (new file, 160 lines)
@@ -0,0 +1,160 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import torch

from ...configuration_utils import FrozenDict
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline


logger = logging.get_logger(__name__)


class Flux2TextInputStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return (
            "This step:\n"
            " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
            " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_images_per_prompt", default=1),
            InputParam(
                "prompt_embeds",
                required=True,
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[str]:
        return [
            OutputParam(
                "batch_size",
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
            ),
            OutputParam(
                "dtype",
                type_hint=torch.dtype,
                description="Data type of model tensor inputs (determined by `prompt_embeds`)",
            ),
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Text embeddings used to guide the image generation",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.batch_size = block_state.prompt_embeds.shape[0]
        block_state.dtype = block_state.prompt_embeds.dtype

        _, seq_len, _ = block_state.prompt_embeds.shape
        block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
        block_state.prompt_embeds = block_state.prompt_embeds.view(
            block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
        )

        self.set_block_state(state, block_state)
        return components, state


class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return "Image preprocess step for Flux2. Validates and preprocesses reference images."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "image_processor",
                Flux2ImageProcessor,
                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image"),
            InputParam("height"),
            InputParam("width"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam(name="condition_images", type_hint=List[torch.Tensor])]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState):
        block_state = self.get_block_state(state)
        images = block_state.image

        if images is None:
            block_state.condition_images = None
            self.set_block_state(state, block_state)
            return components, state

        if not isinstance(images, list):
            images = [images]

        condition_images = []
        for img in images:
            components.image_processor.check_image_input(img)

            image_width, image_height = img.size
            if image_width * image_height > 1024 * 1024:
                img = components.image_processor._resize_to_target_area(img, 1024 * 1024)
                image_width, image_height = img.size

            multiple_of = components.vae_scale_factor * 2
            image_width = (image_width // multiple_of) * multiple_of
            image_height = (image_height // multiple_of) * multiple_of
            condition_img = components.image_processor.preprocess(
                img, height=image_height, width=image_width, resize_mode="crop"
            )
            condition_images.append(condition_img)

        if block_state.height is None:
            block_state.height = image_height
        if block_state.width is None:
            block_state.width = image_width

        block_state.condition_images = condition_images

        self.set_block_state(state, block_state)
        return components, state
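`Flux2ProcessImagesInputStep` caps each reference image at roughly one megapixel and snaps its sides to a multiple of `vae_scale_factor * 2`. A small sketch of that bookkeeping with plain numbers; the scale factor value and the area-preserving downscale are assumptions for illustration, not the exact `_resize_to_target_area` behavior:

```py
def fit_condition_size(width: int, height: int, vae_scale_factor: int = 16, max_area: int = 1024 * 1024):
    # assumed area-preserving downscale when the image is larger than max_area
    if width * height > max_area:
        scale = (max_area / (width * height)) ** 0.5
        width, height = int(width * scale), int(height * scale)
    # snap both sides down to a multiple of vae_scale_factor * 2
    multiple_of = vae_scale_factor * 2
    return (width // multiple_of) * multiple_of, (height // multiple_of) * multiple_of


print(fit_condition_size(1536, 1152))  # (1152, 864)
```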
src/diffusers/modular_pipelines/flux2/modular_blocks.py (new file, 166 lines)
@@ -0,0 +1,166 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
    Flux2PrepareImageLatentsStep,
    Flux2PrepareLatentsStep,
    Flux2RoPEInputsStep,
    Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .denoise import Flux2DenoiseStep
from .encoders import (
    Flux2RemoteTextEncoderStep,
    Flux2TextEncoderStep,
    Flux2VaeEncoderStep,
)
from .inputs import (
    Flux2ProcessImagesInputStep,
    Flux2TextInputStep,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


Flux2VaeEncoderBlocks = InsertableDict(
    [
        ("preprocess", Flux2ProcessImagesInputStep()),
        ("encode", Flux2VaeEncoderStep()),
        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
    ]
)


class Flux2VaeEncoderSequentialStep(SequentialPipelineBlocks):
    model_name = "flux2"

    block_classes = Flux2VaeEncoderBlocks.values()
    block_names = Flux2VaeEncoderBlocks.keys()

    @property
    def description(self) -> str:
        return "VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning."


class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
    block_classes = [Flux2VaeEncoderSequentialStep]
    block_names = ["img_conditioning"]
    block_trigger_inputs = ["image"]

    @property
    def description(self):
        return (
            "VAE encoder step that encodes the image inputs into their latent representations.\n"
            "This is an auto pipeline block that works for image conditioning tasks.\n"
            " - `Flux2VaeEncoderSequentialStep` is used when `image` is provided.\n"
            " - If `image` is not provided, step will be skipped."
        )


Flux2BeforeDenoiseBlocks = InsertableDict(
    [
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
    ]
)


class Flux2BeforeDenoiseStep(SequentialPipelineBlocks):
    model_name = "flux2"

    block_classes = Flux2BeforeDenoiseBlocks.values()
    block_names = Flux2BeforeDenoiseBlocks.keys()

    @property
    def description(self):
        return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation."


AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", Flux2TextEncoderStep()),
        ("text_input", Flux2TextInputStep()),
        ("vae_image_encoder", Flux2AutoVaeEncoderStep()),
        ("before_denoise", Flux2BeforeDenoiseStep()),
        ("denoise", Flux2DenoiseStep()),
        ("decode", Flux2DecodeStep()),
    ]
)


REMOTE_AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", Flux2RemoteTextEncoderStep()),
        ("text_input", Flux2TextInputStep()),
        ("vae_image_encoder", Flux2AutoVaeEncoderStep()),
        ("before_denoise", Flux2BeforeDenoiseStep()),
        ("denoise", Flux2DenoiseStep()),
        ("decode", Flux2DecodeStep()),
    ]
)


class Flux2AutoBlocks(SequentialPipelineBlocks):
    model_name = "flux2"

    block_classes = AUTO_BLOCKS.values()
    block_names = AUTO_BLOCKS.keys()

    @property
    def description(self):
        return (
            "Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2.\n"
            "- For text-to-image generation, all you need to provide is `prompt`.\n"
            "- For image-conditioned generation, you need to provide `image` (list of PIL images)."
        )


TEXT2IMAGE_BLOCKS = InsertableDict(
    [
        ("text_encoder", Flux2TextEncoderStep()),
        ("text_input", Flux2TextInputStep()),
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
        ("denoise", Flux2DenoiseStep()),
        ("decode", Flux2DecodeStep()),
    ]
)

IMAGE_CONDITIONED_BLOCKS = InsertableDict(
    [
        ("text_encoder", Flux2TextEncoderStep()),
        ("text_input", Flux2TextInputStep()),
        ("preprocess_images", Flux2ProcessImagesInputStep()),
        ("vae_encoder", Flux2VaeEncoderStep()),
        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
        ("denoise", Flux2DenoiseStep()),
        ("decode", Flux2DecodeStep()),
    ]
)

ALL_BLOCKS = {
    "text2image": TEXT2IMAGE_BLOCKS,
    "image_conditioned": IMAGE_CONDITIONED_BLOCKS,
    "auto": AUTO_BLOCKS,
    "remote": REMOTE_AUTO_BLOCKS,
}
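These presets are plain `InsertableDict`s, so they can be assembled into a runnable pipeline with the generic modular-pipelines API. The sketch below assumes `from_blocks_dict`, `init_pipeline`, `load_default_components`, and the `output="images"` call pattern behave as described in the modular-pipelines docs, and the repo id is a placeholder:

```py
import torch

from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.flux2.modular_blocks import TEXT2IMAGE_BLOCKS

# Turn the text-to-image preset into blocks, then into a pipeline (repo id is a placeholder).
blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
pipeline = blocks.init_pipeline("<flux2-modular-repo>")
pipeline.load_default_components(torch_dtype=torch.bfloat16)

image = pipeline(prompt="a photo of a cat", num_inference_steps=28, output="images")[0]
```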
src/diffusers/modular_pipelines/flux2/modular_pipeline.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
    """
    A ModularPipeline for Flux2.

    > [!WARNING]
    > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "Flux2AutoBlocks"

    @property
    def default_height(self):
        return self.default_sample_size * self.vae_scale_factor

    @property
    def default_width(self):
        return self.default_sample_size * self.vae_scale_factor

    @property
    def default_sample_size(self):
        return 128

    @property
    def vae_scale_factor(self):
        vae_scale_factor = 8
        if getattr(self, "vae", None) is not None:
            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        return vae_scale_factor

    @property
    def num_channels_latents(self):
        num_channels_latents = 32
        if getattr(self, "transformer", None):
            num_channels_latents = self.transformer.config.in_channels // 4
        return num_channels_latents
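The default output size follows directly from the two properties above. A quick numeric check with a hypothetical VAE config (the `block_out_channels` list is made up; an actual pipeline reads it from the loaded VAE):

```py
block_out_channels = [128, 256, 512, 512]  # hypothetical 4-stage VAE config
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
print(vae_scale_factor)            # 8
print(128 * vae_scale_factor)      # 1024 -> default_height / default_width in pixels
```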
@@ -58,9 +58,11 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
         ("wan", "WanModularPipeline"),
         ("flux", "FluxModularPipeline"),
         ("flux-kontext", "FluxKontextModularPipeline"),
+        ("flux2", "Flux2ModularPipeline"),
         ("qwenimage", "QwenImageModularPipeline"),
         ("qwenimage-edit", "QwenImageEditModularPipeline"),
         ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
+        ("z-image", "ZImageModularPipeline"),
     ]
 )
@@ -1585,7 +1587,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
         for name, config_spec in self._config_specs.items():
             default_configs[name] = config_spec.default
         self.register_to_config(**default_configs)
-
         self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)

     @property
@@ -610,7 +610,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
|||||||
block_state = self.get_block_state(state)
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
# for edit, image size can be different from the target size (height/width)
|
# for edit, image size can be different from the target size (height/width)
|
||||||
|
|
||||||
block_state.img_shapes = [
|
block_state.img_shapes = [
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
@@ -640,6 +639,37 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
|
||||||
|
model_name = "qwenimage-edit-plus"
|
||||||
|
|
||||||
|
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
|
vae_scale_factor = components.vae_scale_factor
|
||||||
|
block_state.img_shapes = [
|
||||||
|
[
|
||||||
|
(1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2),
|
||||||
|
*[
|
||||||
|
(1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2)
|
||||||
|
for vae_height, vae_width in zip(block_state.image_height, block_state.image_width)
|
||||||
|
],
|
||||||
|
]
|
||||||
|
] * block_state.batch_size
|
||||||
|
|
||||||
|
block_state.txt_seq_lens = (
|
||||||
|
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
|
||||||
|
)
|
||||||
|
block_state.negative_txt_seq_lens = (
|
||||||
|
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
|
||||||
|
if block_state.negative_prompt_embeds_mask is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
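The Edit Plus RoPE step above builds `img_shapes` from the target latent grid plus one entry per reference image, repeated per batch item. A toy version with made-up sizes (and assuming the usual QwenImage `vae_scale_factor` of 8):

```py
vae_scale_factor = 8
height, width = 1024, 1024                # target size
image_sizes = [(768, 512), (1024, 768)]   # per-reference (height, width), hypothetical
batch_size = 2

img_shapes = [
    [
        (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2),
        *[(1, h // vae_scale_factor // 2, w // vae_scale_factor // 2) for h, w in image_sizes],
    ]
] * batch_size
print(img_shapes[0])  # [(1, 64, 64), (1, 48, 32), (1, 64, 48)]
```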
## ControlNet inputs for denoiser
|
## ControlNet inputs for denoiser
|
||||||
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
|
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
|
||||||
model_name = "qwenimage"
|
model_name = "qwenimage"
|
||||||
|
|||||||
@@ -330,7 +330,7 @@ class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep):
|
|||||||
output_name: str = "resized_image",
|
output_name: str = "resized_image",
|
||||||
vae_image_output_name: str = "vae_image",
|
vae_image_output_name: str = "vae_image",
|
||||||
):
|
):
|
||||||
"""Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
|
"""Create a configurable step for resizing images to the target area (384 * 384) while maintaining the aspect ratio.
|
||||||
|
|
||||||
This block resizes an input image or a list input images and exposes the resized result under configurable
|
This block resizes an input image or a list input images and exposes the resized result under configurable
|
||||||
input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
|
input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
|
||||||
@@ -809,9 +809,7 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def intermediate_outputs(self) -> List[OutputParam]:
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
return [
|
return [OutputParam(name="processed_image")]
|
||||||
OutputParam(name="processed_image"),
|
|
||||||
]
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_inputs(height, width, vae_scale_factor):
|
def check_inputs(height, width, vae_scale_factor):
|
||||||
@@ -851,7 +849,10 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
|
|||||||
|
|
||||||
class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
|
class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
|
||||||
model_name = "qwenimage-edit-plus"
|
model_name = "qwenimage-edit-plus"
|
||||||
vae_image_size = 1024 * 1024
|
|
||||||
|
def __init__(self):
|
||||||
|
self.vae_image_size = 1024 * 1024
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
@@ -868,6 +869,7 @@ class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
|
|||||||
if block_state.vae_image is None and block_state.image is None:
|
if block_state.vae_image is None and block_state.image is None:
|
||||||
raise ValueError("`vae_image` and `image` cannot be None at the same time")
|
raise ValueError("`vae_image` and `image` cannot be None at the same time")
|
||||||
|
|
||||||
|
vae_image_sizes = None
|
||||||
if block_state.vae_image is None:
|
if block_state.vae_image is None:
|
||||||
image = block_state.image
|
image = block_state.image
|
||||||
self.check_inputs(
|
self.check_inputs(
|
||||||
@@ -879,12 +881,19 @@ class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
|
|||||||
image=image, height=height, width=width
|
image=image, height=height, width=width
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
width, height = block_state.vae_image[0].size
|
# QwenImage Edit Plus can allow multiple input images with varied resolutions
|
||||||
image = block_state.vae_image
|
processed_images = []
|
||||||
|
vae_image_sizes = []
|
||||||
|
for img in block_state.vae_image:
|
||||||
|
width, height = img.size
|
||||||
|
vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, width / height)
|
||||||
|
vae_image_sizes.append((vae_width, vae_height))
|
||||||
|
processed_images.append(
|
||||||
|
components.image_processor.preprocess(image=img, height=vae_height, width=vae_width)
|
||||||
|
)
|
||||||
|
block_state.processed_image = processed_images
|
||||||
|
|
||||||
block_state.processed_image = components.image_processor.preprocess(
|
block_state.vae_image_sizes = vae_image_sizes
|
||||||
image=image, height=height, width=width
|
|
||||||
)
|
|
||||||
|
|
||||||
self.set_block_state(state, block_state)
|
self.set_block_state(state, block_state)
|
||||||
return components, state
|
return components, state
|
||||||
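In the Edit Plus preprocessing above, each reference image is resized towards `vae_image_size` (1024 * 1024 pixels) while keeping its own aspect ratio. The helper below is a rough stand-in for `calculate_dimensions`, not its exact implementation; the rounding to multiples of 32 is an assumption:

```py
def calculate_dimensions_sketch(target_area: int, aspect_ratio: float, multiple_of: int = 32):
    # pick width/height with the given aspect ratio whose product is close to target_area
    width = (target_area * aspect_ratio) ** 0.5
    height = width / aspect_ratio
    return round(width / multiple_of) * multiple_of, round(height / multiple_of) * multiple_of


for w, h in [(640, 480), (2048, 1152)]:
    print(calculate_dimensions_sketch(1024 * 1024, w / h))  # roughly one-megapixel sizes
```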
@@ -926,17 +935,12 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def expected_components(self) -> List[ComponentSpec]:
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
components = [
|
components = [ComponentSpec("vae", AutoencoderKLQwenImage)]
|
||||||
ComponentSpec("vae", AutoencoderKLQwenImage),
|
|
||||||
]
|
|
||||||
return components
|
return components
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> List[InputParam]:
|
def inputs(self) -> List[InputParam]:
|
||||||
inputs = [
|
inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")]
|
||||||
InputParam(self._image_input_name, required=True),
|
|
||||||
InputParam("generator"),
|
|
||||||
]
|
|
||||||
return inputs
|
return inputs
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -974,6 +978,50 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageEditPlusVaeEncoderDynamicStep(QwenImageVaeEncoderDynamicStep):
|
||||||
|
model_name = "qwenimage-edit-plus"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
# Each reference image latent can have varied resolutions hence we return this as a list.
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
self._image_latents_output_name,
|
||||||
|
type_hint=List[torch.Tensor],
|
||||||
|
description="The latents representing the reference image(s).",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
|
device = components._execution_device
|
||||||
|
dtype = components.vae.dtype
|
||||||
|
|
||||||
|
image = getattr(block_state, self._image_input_name)
|
||||||
|
|
||||||
|
# Encode image into latents
|
||||||
|
image_latents = []
|
||||||
|
for img in image:
|
||||||
|
image_latents.append(
|
||||||
|
encode_vae_image(
|
||||||
|
image=img,
|
||||||
|
vae=components.vae,
|
||||||
|
generator=block_state.generator,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
latent_channels=components.num_channels_latents,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
setattr(block_state, self._image_latents_output_name, image_latents)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
|
class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
|
||||||
model_name = "qwenimage"
|
model_name = "qwenimage"
|
||||||
|
|
||||||
|
|||||||
@@ -224,11 +224,7 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
|
|||||||
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||||
model_name = "qwenimage"
|
model_name = "qwenimage"
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additional_batch_inputs: List[str] = []):
|
||||||
self,
|
|
||||||
image_latent_inputs: List[str] = ["image_latents"],
|
|
||||||
additional_batch_inputs: List[str] = [],
|
|
||||||
):
|
|
||||||
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
|
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
|
||||||
|
|
||||||
This step handles multiple common tasks to prepare inputs for the denoising step:
|
This step handles multiple common tasks to prepare inputs for the denoising step:
|
||||||
@@ -372,6 +368,76 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
|||||||
return components, state
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep):
|
||||||
|
model_name = "qwenimage-edit-plus"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam(name="image_height", type_hint=List[int], description="The height of the image latents"),
|
||||||
|
OutputParam(name="image_width", type_hint=List[int], description="The width of the image latents"),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
|
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
|
||||||
|
for image_latent_input_name in self._image_latent_inputs:
|
||||||
|
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||||
|
if image_latent_tensor is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Each image latent can have different size in QwenImage Edit Plus.
|
||||||
|
image_heights = []
|
||||||
|
image_widths = []
|
||||||
|
packed_image_latent_tensors = []
|
||||||
|
|
||||||
|
for img_latent_tensor in image_latent_tensor:
|
||||||
|
# 1. Calculate height/width from latents
|
||||||
|
height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor)
|
||||||
|
image_heights.append(height)
|
||||||
|
image_widths.append(width)
|
||||||
|
|
||||||
|
# 2. Patchify the image latent tensor
|
||||||
|
img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor)
|
||||||
|
|
||||||
|
# 3. Expand batch size
|
||||||
|
img_latent_tensor = repeat_tensor_to_batch_size(
|
||||||
|
input_name=image_latent_input_name,
|
||||||
|
input_tensor=img_latent_tensor,
|
||||||
|
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||||
|
batch_size=block_state.batch_size,
|
||||||
|
)
|
||||||
|
packed_image_latent_tensors.append(img_latent_tensor)
|
||||||
|
|
||||||
|
packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1)
|
||||||
|
block_state.image_height = image_heights
|
||||||
|
block_state.image_width = image_widths
|
||||||
|
setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
|
||||||
|
|
||||||
|
block_state.height = block_state.height or image_heights[-1]
|
||||||
|
block_state.width = block_state.width or image_widths[-1]
|
||||||
|
|
||||||
|
# Process additional batch inputs (only batch expansion)
|
||||||
|
for input_name in self._additional_batch_inputs:
|
||||||
|
input_tensor = getattr(block_state, input_name)
|
||||||
|
if input_tensor is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Only expand batch size
|
||||||
|
input_tensor = repeat_tensor_to_batch_size(
|
||||||
|
input_name=input_name,
|
||||||
|
input_tensor=input_tensor,
|
||||||
|
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||||
|
batch_size=block_state.batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
setattr(block_state, input_name, input_tensor)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
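Because every reference latent can have a different spatial size, the Edit Plus input step above packs each one into a `(batch, seq, channels * 4)` sequence and concatenates along the sequence axis. A dummy-tensor sketch of that shape flow (the packing helper mirrors the usual 2x2 patchify, not the exact `pachifier` implementation):

```py
import torch


def pack_latents(latents: torch.Tensor) -> torch.Tensor:
    # (B, C, H, W) -> (B, H/2 * W/2, C * 4)
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2).permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)


latents_a = torch.randn(1, 16, 96, 64)  # reference image 1, made-up latent size
latents_b = torch.randn(1, 16, 64, 64)  # reference image 2, different size
packed = torch.cat([pack_latents(latents_a), pack_latents(latents_b)], dim=1)
print(packed.shape)  # torch.Size([1, 2560, 64])
```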
class QwenImageControlNetInputsStep(ModularPipelineBlocks):
|
class QwenImageControlNetInputsStep(ModularPipelineBlocks):
|
||||||
model_name = "qwenimage"
|
model_name = "qwenimage"
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from ..modular_pipeline_utils import InsertableDict
|
|||||||
from .before_denoise import (
|
from .before_denoise import (
|
||||||
QwenImageControlNetBeforeDenoiserStep,
|
QwenImageControlNetBeforeDenoiserStep,
|
||||||
QwenImageCreateMaskLatentsStep,
|
QwenImageCreateMaskLatentsStep,
|
||||||
|
QwenImageEditPlusRoPEInputsStep,
|
||||||
QwenImageEditRoPEInputsStep,
|
QwenImageEditRoPEInputsStep,
|
||||||
QwenImagePrepareLatentsStep,
|
QwenImagePrepareLatentsStep,
|
||||||
QwenImagePrepareLatentsWithStrengthStep,
|
QwenImagePrepareLatentsWithStrengthStep,
|
||||||
@@ -40,6 +41,7 @@ from .encoders import (
|
|||||||
QwenImageEditPlusProcessImagesInputStep,
|
QwenImageEditPlusProcessImagesInputStep,
|
||||||
QwenImageEditPlusResizeDynamicStep,
|
QwenImageEditPlusResizeDynamicStep,
|
||||||
QwenImageEditPlusTextEncoderStep,
|
QwenImageEditPlusTextEncoderStep,
|
||||||
|
QwenImageEditPlusVaeEncoderDynamicStep,
|
||||||
QwenImageEditResizeDynamicStep,
|
QwenImageEditResizeDynamicStep,
|
||||||
QwenImageEditTextEncoderStep,
|
QwenImageEditTextEncoderStep,
|
||||||
QwenImageInpaintProcessImagesInputStep,
|
QwenImageInpaintProcessImagesInputStep,
|
||||||
@@ -47,7 +49,12 @@ from .encoders import (
|
|||||||
QwenImageTextEncoderStep,
|
QwenImageTextEncoderStep,
|
||||||
QwenImageVaeEncoderDynamicStep,
|
QwenImageVaeEncoderDynamicStep,
|
||||||
)
|
)
|
||||||
from .inputs import QwenImageControlNetInputsStep, QwenImageInputsDynamicStep, QwenImageTextInputsStep
|
from .inputs import (
|
||||||
|
QwenImageControlNetInputsStep,
|
||||||
|
QwenImageEditPlusInputsDynamicStep,
|
||||||
|
QwenImageInputsDynamicStep,
|
||||||
|
QwenImageTextInputsStep,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -904,13 +911,13 @@ QwenImageEditPlusVaeEncoderBlocks = InsertableDict(
|
|||||||
[
|
[
|
||||||
("resize", QwenImageEditPlusResizeDynamicStep()), # edit plus has a different resize step
|
("resize", QwenImageEditPlusResizeDynamicStep()), # edit plus has a different resize step
|
||||||
("preprocess", QwenImageEditPlusProcessImagesInputStep()), # vae_image -> processed_image
|
("preprocess", QwenImageEditPlusProcessImagesInputStep()), # vae_image -> processed_image
|
||||||
("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
|
("encode", QwenImageEditPlusVaeEncoderDynamicStep()), # processed_image -> image_latents
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
|
class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
|
||||||
model_name = "qwenimage"
|
model_name = "qwenimage-edit-plus"
|
||||||
block_classes = QwenImageEditPlusVaeEncoderBlocks.values()
|
block_classes = QwenImageEditPlusVaeEncoderBlocks.values()
|
||||||
block_names = QwenImageEditPlusVaeEncoderBlocks.keys()
|
block_names = QwenImageEditPlusVaeEncoderBlocks.keys()
|
||||||
|
|
||||||
@@ -919,25 +926,62 @@ class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
|
|||||||
return "Vae encoder step that encode the image inputs into their latent representations."
|
return "Vae encoder step that encode the image inputs into their latent representations."
|
||||||
|
|
||||||
|
|
||||||
|
#### QwenImage Edit Plus input blocks
|
||||||
|
QwenImageEditPlusInputBlocks = InsertableDict(
|
||||||
|
[
|
||||||
|
("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings
|
||||||
|
(
|
||||||
|
"additional_inputs",
|
||||||
|
QwenImageEditPlusInputsDynamicStep(image_latent_inputs=["image_latents"]),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
|
||||||
|
model_name = "qwenimage-edit-plus"
|
||||||
|
block_classes = QwenImageEditPlusInputBlocks.values()
|
||||||
|
block_names = QwenImageEditPlusInputBlocks.keys()
|
||||||
|
|
||||||
|
|
||||||
#### QwenImage Edit Plus presets
|
#### QwenImage Edit Plus presets
|
||||||
EDIT_PLUS_BLOCKS = InsertableDict(
|
EDIT_PLUS_BLOCKS = InsertableDict(
|
||||||
[
|
[
|
||||||
("text_encoder", QwenImageEditPlusVLEncoderStep()),
|
("text_encoder", QwenImageEditPlusVLEncoderStep()),
|
||||||
("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
|
("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
|
||||||
("input", QwenImageEditInputStep()),
|
("input", QwenImageEditPlusInputStep()),
|
||||||
("prepare_latents", QwenImagePrepareLatentsStep()),
|
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||||
("set_timesteps", QwenImageSetTimestepsStep()),
|
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||||
("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
|
("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
|
||||||
("denoise", QwenImageEditDenoiseStep()),
|
("denoise", QwenImageEditDenoiseStep()),
|
||||||
("decode", QwenImageDecodeStep()),
|
("decode", QwenImageDecodeStep()),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
QwenImageEditPlusBeforeDenoiseBlocks = InsertableDict(
|
||||||
|
[
|
||||||
|
("prepare_latents", QwenImagePrepareLatentsStep()),
|
||||||
|
("set_timesteps", QwenImageSetTimestepsStep()),
|
||||||
|
("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageEditPlusBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||||
|
model_name = "qwenimage-edit-plus"
|
||||||
|
block_classes = QwenImageEditPlusBeforeDenoiseBlocks.values()
|
||||||
|
block_names = QwenImageEditPlusBeforeDenoiseBlocks.keys()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task."
|
||||||
|
|
||||||
|
|
||||||
# auto before_denoise step for edit tasks
|
# auto before_denoise step for edit tasks
|
||||||
class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||||
model_name = "qwenimage-edit-plus"
|
model_name = "qwenimage-edit-plus"
|
||||||
block_classes = [QwenImageEditBeforeDenoiseStep]
|
block_classes = [QwenImageEditPlusBeforeDenoiseStep]
|
||||||
block_names = ["edit"]
|
block_names = ["edit"]
|
||||||
block_trigger_inputs = ["image_latents"]
|
block_trigger_inputs = ["image_latents"]
|
||||||
|
|
||||||
@@ -946,7 +990,7 @@ class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
|||||||
return (
|
return (
|
||||||
"Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
|
"Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
|
||||||
+ "This is an auto pipeline block that works for edit (img2img) task.\n"
|
+ "This is an auto pipeline block that works for edit (img2img) task.\n"
|
||||||
+ " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
|
+ " - `QwenImageEditPlusBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
|
||||||
+ " - if `image_latents` is not provided, step will be skipped."
|
+ " - if `image_latents` is not provided, step will be skipped."
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -955,9 +999,7 @@ class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
|||||||
|
|
||||||
|
|
||||||
class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks):
|
class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||||
block_classes = [
|
block_classes = [QwenImageEditPlusVaeEncoderStep]
|
||||||
QwenImageEditPlusVaeEncoderStep,
|
|
||||||
]
|
|
||||||
block_names = ["edit"]
|
block_names = ["edit"]
|
||||||
block_trigger_inputs = ["image"]
|
block_trigger_inputs = ["image"]
|
||||||
|
|
||||||
@@ -974,10 +1016,25 @@ class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks):
|
|||||||
## 3.3 QwenImage-Edit/auto blocks & presets
|
## 3.3 QwenImage-Edit/auto blocks & presets
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageEditPlusAutoInputStep(AutoPipelineBlocks):
|
||||||
|
block_classes = [QwenImageEditPlusInputStep]
|
||||||
|
block_names = ["edit"]
|
||||||
|
block_trigger_inputs = ["image_latents"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return (
|
||||||
|
"Input step that prepares the inputs for the edit denoising step.\n"
|
||||||
|
+ " It is an auto pipeline block that works for edit task.\n"
|
||||||
|
+ " - `QwenImageEditPlusInputStep` (edit) is used when `image_latents` is provided.\n"
|
||||||
|
+ " - if `image_latents` is not provided, step will be skipped."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
|
class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
|
||||||
model_name = "qwenimage-edit-plus"
|
model_name = "qwenimage-edit-plus"
|
||||||
block_classes = [
|
block_classes = [
|
||||||
QwenImageEditAutoInputStep,
|
QwenImageEditPlusAutoInputStep,
|
||||||
QwenImageEditPlusAutoBeforeDenoiseStep,
|
QwenImageEditPlusAutoBeforeDenoiseStep,
|
||||||
QwenImageEditAutoDenoiseStep,
|
QwenImageEditAutoDenoiseStep,
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -530,6 +530,7 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks):

         device = components._execution_device
         dtype = torch.float32
+        vae_dtype = components.vae.dtype

         height = block_state.height or components.default_height
         width = block_state.width or components.default_width
@@ -555,7 +556,7 @@ class WanVaeImageEncoderStep(ModularPipelineBlocks):
             vae=components.vae,
             generator=block_state.generator,
             device=device,
-            dtype=dtype,
+            dtype=vae_dtype,
             latent_channels=components.num_channels_latents,
         )

@@ -627,6 +628,7 @@ class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):

         device = components._execution_device
         dtype = torch.float32
+        vae_dtype = components.vae.dtype

         height = block_state.height or components.default_height
         width = block_state.width or components.default_width
@@ -659,7 +661,7 @@ class WanFirstLastFrameVaeImageEncoderStep(ModularPipelineBlocks):
             vae=components.vae,
             generator=block_state.generator,
             device=device,
-            dtype=dtype,
+            dtype=vae_dtype,
             latent_channels=components.num_channels_latents,
         )

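The Wan hunks above replace the hard-coded `torch.float32` encoding dtype with the VAE's own dtype when calling `encode_vae_image`. A minimal sketch of the idea outside the modular-pipeline machinery (the randomly initialized VAE and dummy image below are illustrative only):

```py
import torch
from diffusers import AutoencoderKL

# Encode in whatever dtype the VAE holds its weights in (e.g. bfloat16) instead of
# forcing float32, so the input tensor matches the module's parameters.
vae = AutoencoderKL().to(dtype=torch.bfloat16)  # randomly initialized, for illustration
image = torch.rand(1, 3, 64, 64) * 2 - 1        # dummy image scaled to [-1, 1]
image = image.to(dtype=vae.dtype)               # cast to the VAE dtype before encoding
with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample()
```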
src/diffusers/modular_pipelines/z_image/__init__.py (new file, 57 lines)
@@ -0,0 +1,57 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["decoders"] = ["ZImageVaeDecoderStep"]
    _import_structure["encoders"] = ["ZImageTextEncoderStep", "ZImageVaeImageEncoderStep"]
    _import_structure["modular_blocks"] = [
        "ALL_BLOCKS",
        "ZImageAutoBlocks",
    ]
    _import_structure["modular_pipeline"] = ["ZImageModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .decoders import ZImageVaeDecoderStep
        from .encoders import ZImageTextEncoderStep
        from .modular_blocks import (
            ALL_BLOCKS,
            ZImageAutoBlocks,
        )
        from .modular_pipeline import ZImageModularPipeline
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
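The new `__init__.py` lazily exposes the Z-Image modular blocks and pipeline. A usage sketch under the existing modular-pipelines conventions; the `init_pipeline`/`load_components` calls and the repository id are assumptions for illustration, not something this diff defines:

```py
import torch
from diffusers.modular_pipelines import ZImageAutoBlocks

# Assemble the auto blocks into a runnable modular pipeline (method names assumed
# from the current modular-pipelines API; the repo id is a placeholder).
blocks = ZImageAutoBlocks()
pipe = blocks.init_pipeline("<z-image-modular-repo>")
pipe.load_components(torch_dtype=torch.bfloat16)

images = pipe(prompt="a cat wearing sunglasses", output="images")
```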
src/diffusers/modular_pipelines/z_image/before_denoise.py (new file, 621 lines)
@@ -0,0 +1,621 @@
|
|||||||
|
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...models import ZImageTransformer2DModel
|
||||||
|
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||||
|
from ...utils import logging
|
||||||
|
from ...utils.torch_utils import randn_tensor
|
||||||
|
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||||
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
|
from .modular_pipeline import ZImageModularPipeline
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that
|
||||||
|
# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by
|
||||||
|
# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the
|
||||||
|
# configuration of guider is.
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_tensor_to_batch_size(
|
||||||
|
input_name: str,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
num_images_per_prompt: int = 1,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Repeat tensor elements to match the final batch size.
|
||||||
|
|
||||||
|
This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt)
|
||||||
|
by repeating each element along dimension 0.
|
||||||
|
|
||||||
|
The input tensor must have batch size 1 or batch_size. The function will:
|
||||||
|
- If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times
|
||||||
|
- If batch size equals batch_size: repeat each element num_images_per_prompt times
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_name (str): Name of the input tensor (used for error messages)
|
||||||
|
input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
|
||||||
|
batch_size (int): The base batch size (number of prompts)
|
||||||
|
num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If input_tensor is not a torch.Tensor or has invalid batch size
|
||||||
|
|
||||||
|
Examples:
    tensor = torch.tensor([[1, 2, 3]])  # shape: [1, 3]
    repeated = repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_images_per_prompt=2)
    repeated  # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: [4, 3]

    tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: [2, 3]
    repeated = repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_images_per_prompt=2)
    repeated  # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) - shape: [4, 3]
"""
|
||||||
|
# make sure input is a tensor
|
||||||
|
if not isinstance(input_tensor, torch.Tensor):
|
||||||
|
raise ValueError(f"`{input_name}` must be a tensor")
|
||||||
|
|
||||||
|
# make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts
|
||||||
|
if input_tensor.shape[0] == 1:
|
||||||
|
repeat_by = batch_size * num_images_per_prompt
|
||||||
|
elif input_tensor.shape[0] == batch_size:
|
||||||
|
repeat_by = num_images_per_prompt
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# expand the tensor to match the batch_size * num_images_per_prompt
|
||||||
|
input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)
|
||||||
|
|
||||||
|
return input_tensor
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor_spatial: int) -> Tuple[int, int]:
|
||||||
|
"""Calculate image dimensions from latent tensor dimensions.
|
||||||
|
|
||||||
|
This function converts latent spatial dimensions to image spatial dimensions by multiplying the latent height/width
|
||||||
|
by the VAE scale factor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
latents (torch.Tensor): The latent tensor. Must have 4 dimensions.
|
||||||
|
Expected shapes: [batch, channels, height, width]
|
||||||
|
vae_scale_factor_spatial (int): The scale factor used by the VAE to compress the image's spatial dimensions.
|
||||||
|
By default, it is 16
|
||||||
|
Returns:
|
||||||
|
Tuple[int, int]: The calculated image dimensions as (height, width)
|
||||||
|
"""
|
||||||
|
latent_height, latent_width = latents.shape[2:]
|
||||||
|
height = latent_height * vae_scale_factor_spatial // 2
|
||||||
|
width = latent_width * vae_scale_factor_spatial // 2
|
||||||
|
|
||||||
|
return height, width
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||||
|
def calculate_shift(
|
||||||
|
image_seq_len,
|
||||||
|
base_seq_len: int = 256,
|
||||||
|
max_seq_len: int = 4096,
|
||||||
|
base_shift: float = 0.5,
|
||||||
|
max_shift: float = 1.15,
|
||||||
|
):
|
||||||
|
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||||
|
b = base_shift - m * base_seq_len
|
||||||
|
mu = image_seq_len * m + b
|
||||||
|
return mu
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||||
|
def retrieve_timesteps(
|
||||||
|
scheduler,
|
||||||
|
num_inference_steps: Optional[int] = None,
|
||||||
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
|
timesteps: Optional[List[int]] = None,
|
||||||
|
sigmas: Optional[List[float]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||||
|
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduler (`SchedulerMixin`):
|
||||||
|
The scheduler to get timesteps from.
|
||||||
|
num_inference_steps (`int`):
|
||||||
|
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||||
|
must be `None`.
|
||||||
|
device (`str` or `torch.device`, *optional*):
|
||||||
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||||
|
timesteps (`List[int]`, *optional*):
|
||||||
|
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||||
|
`num_inference_steps` and `sigmas` must be `None`.
|
||||||
|
sigmas (`List[float]`, *optional*):
|
||||||
|
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||||
|
`num_inference_steps` and `timesteps` must be `None`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||||
|
second element is the number of inference steps.
|
||||||
|
"""
|
||||||
|
if timesteps is not None and sigmas is not None:
|
||||||
|
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||||
|
if timesteps is not None:
|
||||||
|
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||||
|
if not accepts_timesteps:
|
||||||
|
raise ValueError(
|
||||||
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||||
|
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||||
|
)
|
||||||
|
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
num_inference_steps = len(timesteps)
|
||||||
|
elif sigmas is not None:
|
||||||
|
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||||
|
if not accept_sigmas:
|
||||||
|
raise ValueError(
|
||||||
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||||
|
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||||
|
)
|
||||||
|
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
num_inference_steps = len(timesteps)
|
||||||
|
else:
|
||||||
|
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
return timesteps, num_inference_steps
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageTextInputStep(ModularPipelineBlocks):
|
||||||
|
model_name = "z-image"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Input processing step that:\n"
|
||||||
|
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
|
||||||
|
" 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n"
|
||||||
|
"All input tensors are expected to have either batch_size=1 or match the batch_size\n"
|
||||||
|
"of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
|
||||||
|
"have a final batch_size of batch_size * num_images_per_prompt."
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("transformer", ZImageTransformer2DModel),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("num_images_per_prompt", default=1),
|
||||||
|
InputParam(
|
||||||
|
"prompt_embeds",
|
||||||
|
required=True,
|
||||||
|
type_hint=List[torch.Tensor],
|
||||||
|
description="Pre-generated text embeddings. Can be generated from text_encoder step.",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"negative_prompt_embeds",
|
||||||
|
type_hint=List[torch.Tensor],
|
||||||
|
description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[str]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"batch_size",
|
||||||
|
type_hint=int,
|
||||||
|
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
|
||||||
|
),
|
||||||
|
OutputParam(
|
||||||
|
"dtype",
|
||||||
|
type_hint=torch.dtype,
|
||||||
|
description="Data type of model tensor inputs (determined by `transformer.dtype`)",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def check_inputs(self, components, block_state):
|
||||||
|
if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None:
|
||||||
|
if not isinstance(block_state.prompt_embeds, list):
|
||||||
|
raise ValueError(
|
||||||
|
f"`prompt_embeds` must be a list when passed directly, but got {type(block_state.prompt_embeds)}."
|
||||||
|
)
|
||||||
|
if not isinstance(block_state.negative_prompt_embeds, list):
|
||||||
|
raise ValueError(
|
||||||
|
f"`negative_prompt_embeds` must be a list when passed directly, but got {type(block_state.negative_prompt_embeds)}."
|
||||||
|
)
|
||||||
|
if len(block_state.prompt_embeds) != len(block_state.negative_prompt_embeds):
|
||||||
|
raise ValueError(
|
||||||
|
"`prompt_embeds` and `negative_prompt_embeds` must have the same length when passed directly, but"
|
||||||
|
f" got: `prompt_embeds` {len(block_state.prompt_embeds)} != `negative_prompt_embeds`"
|
||||||
|
f" {len(block_state.negative_prompt_embeds)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
self.check_inputs(components, block_state)
|
||||||
|
|
||||||
|
block_state.batch_size = len(block_state.prompt_embeds)
|
||||||
|
block_state.dtype = block_state.prompt_embeds[0].dtype
|
||||||
|
|
||||||
|
if block_state.num_images_per_prompt > 1:
|
||||||
|
prompt_embeds = [pe for pe in block_state.prompt_embeds for _ in range(block_state.num_images_per_prompt)]
|
||||||
|
block_state.prompt_embeds = prompt_embeds
|
||||||
|
|
||||||
|
if block_state.negative_prompt_embeds is not None:
|
||||||
|
negative_prompt_embeds = [
|
||||||
|
npe for npe in block_state.negative_prompt_embeds for _ in range(block_state.num_images_per_prompt)
|
||||||
|
]
|
||||||
|
block_state.negative_prompt_embeds = negative_prompt_embeds
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageAdditionalInputsStep(ModularPipelineBlocks):
|
||||||
|
model_name = "z-image"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
image_latent_inputs: List[str] = ["image_latents"],
|
||||||
|
additional_batch_inputs: List[str] = [],
|
||||||
|
):
|
||||||
|
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
|
||||||
|
|
||||||
|
This step handles multiple common tasks to prepare inputs for the denoising step:
|
||||||
|
1. For encoded image latents: updates height/width if None, and expands the batch size
|
||||||
|
2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
|
||||||
|
|
||||||
|
This is a dynamic block that allows you to configure which inputs to process.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
|
||||||
|
In addition to adjusting the batch size of these inputs, they will be used to determine height/width. Can be
|
||||||
|
a single string or list of strings. Defaults to ["image_latents"].
|
||||||
|
additional_batch_inputs (List[str], optional):
|
||||||
|
Names of additional conditional input tensors to expand batch size. These tensors will only have their
|
||||||
|
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
|
||||||
|
Defaults to [].
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
# Configure to process image_latents (default behavior) ZImageAdditionalInputsStep()
|
||||||
|
|
||||||
|
# Configure to process multiple image latent inputs
|
||||||
|
ZImageAdditionalInputsStep(image_latent_inputs=["image_latents", "control_image_latents"])
|
||||||
|
|
||||||
|
# Configure to process image latents and additional batch inputs ZImageAdditionalInputsStep(
|
||||||
|
image_latent_inputs=["image_latents"], additional_batch_inputs=["image_embeds"]
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
if not isinstance(image_latent_inputs, list):
|
||||||
|
image_latent_inputs = [image_latent_inputs]
|
||||||
|
if not isinstance(additional_batch_inputs, list):
|
||||||
|
additional_batch_inputs = [additional_batch_inputs]
|
||||||
|
|
||||||
|
self._image_latent_inputs = image_latent_inputs
|
||||||
|
self._additional_batch_inputs = additional_batch_inputs
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
# Functionality section
|
||||||
|
summary_section = (
|
||||||
|
"Input processing step that:\n"
|
||||||
|
" 1. For image latent inputs: Updates height/width if None, and expands batch size\n"
|
||||||
|
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Inputs info
|
||||||
|
inputs_info = ""
|
||||||
|
if self._image_latent_inputs or self._additional_batch_inputs:
|
||||||
|
inputs_info = "\n\nConfigured inputs:"
|
||||||
|
if self._image_latent_inputs:
|
||||||
|
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
|
||||||
|
if self._additional_batch_inputs:
|
||||||
|
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
|
||||||
|
|
||||||
|
# Placement guidance
|
||||||
|
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
|
||||||
|
|
||||||
|
return summary_section + inputs_info + placement_section
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
inputs = [
|
||||||
|
InputParam(name="num_images_per_prompt", default=1),
|
||||||
|
InputParam(name="batch_size", required=True),
|
||||||
|
InputParam(name="height"),
|
||||||
|
InputParam(name="width"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add image latent inputs
|
||||||
|
for image_latent_input_name in self._image_latent_inputs:
|
||||||
|
inputs.append(InputParam(name=image_latent_input_name))
|
||||||
|
|
||||||
|
# Add additional batch inputs
|
||||||
|
for input_name in self._additional_batch_inputs:
|
||||||
|
inputs.append(InputParam(name=input_name))
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
|
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
|
||||||
|
for image_latent_input_name in self._image_latent_inputs:
|
||||||
|
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||||
|
if image_latent_tensor is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 1. Calculate num_frames, height/width from latents
|
||||||
|
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor_spatial)
|
||||||
|
block_state.height = block_state.height or height
|
||||||
|
block_state.width = block_state.width or width
|
||||||
|
|
||||||
|
# Process additional batch inputs (only batch expansion)
|
||||||
|
for input_name in self._additional_batch_inputs:
|
||||||
|
input_tensor = getattr(block_state, input_name)
|
||||||
|
if input_tensor is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Only expand batch size
|
||||||
|
input_tensor = repeat_tensor_to_batch_size(
|
||||||
|
input_name=input_name,
|
||||||
|
input_tensor=input_tensor,
|
||||||
|
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||||
|
batch_size=block_state.batch_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
setattr(block_state, input_name, input_tensor)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class ZImagePrepareLatentsStep(ModularPipelineBlocks):
|
||||||
|
model_name = "z-image"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Prepare latents step that prepares the latents for the text-to-video generation process"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("height", type_hint=int),
|
||||||
|
InputParam("width", type_hint=int),
|
||||||
|
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||||
|
InputParam("num_images_per_prompt", type_hint=int, default=1),
|
||||||
|
InputParam("generator"),
|
||||||
|
InputParam(
|
||||||
|
"batch_size",
|
||||||
|
required=True,
|
||||||
|
type_hint=int,
|
||||||
|
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
|
||||||
|
),
|
||||||
|
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def check_inputs(self, components, block_state):
|
||||||
|
if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
|
||||||
|
block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.prepare_latents with self->comp
|
||||||
|
def prepare_latents(
|
||||||
|
comp,
|
||||||
|
batch_size,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents=None,
|
||||||
|
):
|
||||||
|
height = 2 * (int(height) // (comp.vae_scale_factor * 2))
|
||||||
|
width = 2 * (int(width) // (comp.vae_scale_factor * 2))
|
||||||
|
|
||||||
|
shape = (batch_size, num_channels_latents, height, width)
|
||||||
|
|
||||||
|
if latents is None:
|
||||||
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
if latents.shape != shape:
|
||||||
|
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||||
|
latents = latents.to(device)
|
||||||
|
return latents
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
self.check_inputs(components, block_state)
|
||||||
|
|
||||||
|
device = components._execution_device
|
||||||
|
dtype = torch.float32
|
||||||
|
|
||||||
|
block_state.height = block_state.height or components.default_height
|
||||||
|
block_state.width = block_state.width or components.default_width
|
||||||
|
|
||||||
|
block_state.latents = self.prepare_latents(
|
||||||
|
components,
|
||||||
|
batch_size=block_state.batch_size * block_state.num_images_per_prompt,
|
||||||
|
num_channels_latents=components.num_channels_latents,
|
||||||
|
height=block_state.height,
|
||||||
|
width=block_state.width,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
generator=block_state.generator,
|
||||||
|
latents=block_state.latents,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageSetTimestepsStep(ModularPipelineBlocks):
|
||||||
|
model_name = "z-image"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Step that sets the scheduler's timesteps for inference. Need to run after prepare latents step."
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("latents", required=True),
|
||||||
|
InputParam("num_inference_steps", default=9),
|
||||||
|
InputParam("sigmas"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
device = components._execution_device
|
||||||
|
|
||||||
|
latent_height, latent_width = block_state.latents.shape[2], block_state.latents.shape[3]
|
||||||
|
image_seq_len = (latent_height // 2) * (latent_width // 2) # sequence length after patchify
|
||||||
|
|
||||||
|
mu = calculate_shift(
|
||||||
|
image_seq_len,
|
||||||
|
base_seq_len=components.scheduler.config.get("base_image_seq_len", 256),
|
||||||
|
max_seq_len=components.scheduler.config.get("max_image_seq_len", 4096),
|
||||||
|
base_shift=components.scheduler.config.get("base_shift", 0.5),
|
||||||
|
max_shift=components.scheduler.config.get("max_shift", 1.15),
|
||||||
|
)
|
||||||
|
components.scheduler.sigma_min = 0.0
|
||||||
|
|
||||||
|
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
|
||||||
|
components.scheduler,
|
||||||
|
block_state.num_inference_steps,
|
||||||
|
device,
|
||||||
|
sigmas=block_state.sigmas,
|
||||||
|
mu=mu,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
|
||||||
|
model_name = "z-image"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Step that sets the scheduler's timesteps for inference with strength. Need to run after set timesteps step."
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("timesteps", required=True),
|
||||||
|
InputParam("num_inference_steps", required=True),
|
||||||
|
InputParam("strength", default=0.6),
|
||||||
|
]
|
||||||
|
|
||||||
|
def check_inputs(self, components, block_state):
|
||||||
|
if block_state.strength < 0.0 or block_state.strength > 1.0:
|
||||||
|
raise ValueError(f"Strength must be between 0.0 and 1.0, but got {block_state.strength}")
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
self.check_inputs(components, block_state)
|
||||||
|
|
||||||
|
init_timestep = min(block_state.num_inference_steps * block_state.strength, block_state.num_inference_steps)
|
||||||
|
|
||||||
|
t_start = int(max(block_state.num_inference_steps - init_timestep, 0))
|
||||||
|
timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :]
|
||||||
|
if hasattr(components.scheduler, "set_begin_index"):
|
||||||
|
components.scheduler.set_begin_index(t_start * components.scheduler.order)
|
||||||
|
|
||||||
|
block_state.timesteps = timesteps
|
||||||
|
block_state.num_inference_steps = block_state.num_inference_steps - t_start
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class ZImagePrepareLatentswithImageStep(ModularPipelineBlocks):
|
||||||
|
model_name = "z-image"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "step that prepares the latents with image condition, need to run after set timesteps and prepare latents step."
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("latents", required=True),
|
||||||
|
InputParam("image_latents", required=True),
|
||||||
|
InputParam("timesteps", required=True),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
|
latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0])
|
||||||
|
block_state.latents = components.scheduler.scale_noise(
|
||||||
|
block_state.image_latents, latent_timestep, block_state.latents
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
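`ZImageSetTimestepsStep` above combines `calculate_shift` and `retrieve_timesteps` so the flow-matching schedule adapts to the latent resolution. A standalone check of the shift computation, using the same constants as the code (the latent size is just an example):

```py
# mu interpolates linearly between base_shift and max_shift as the image sequence length grows.
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

# Example: a 128x128 latent patchified 2x2 gives a 64 * 64 = 4096 token sequence,
# which lands exactly on max_shift.
print(calculate_shift(64 * 64))  # 1.15
```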
src/diffusers/modular_pipelines/z_image/decoders.py (new file, 91 lines)
@@ -0,0 +1,91 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Tuple, Union

import numpy as np
import PIL
import torch

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class ZImageVaeDecoderStep(ModularPipelineBlocks):
    model_name = "z-image"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8 * 2}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Step that decodes the denoised latents into images"

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam(
                "latents",
                required=True,
            ),
            InputParam(
                name="output_type",
                default="pil",
                type_hint=str,
                description="The type of the output images, can be 'pil', 'np', 'pt'",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[str]:
        return [
            OutputParam(
                "images",
                type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]],
                description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
            )
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        vae_dtype = components.vae.dtype

        latents = block_state.latents.to(vae_dtype)
        latents = latents / components.vae.config.scaling_factor + components.vae.config.shift_factor

        block_state.images = components.vae.decode(latents, return_dict=False)[0]
        block_state.images = components.image_processor.postprocess(
            block_state.images, output_type=block_state.output_type
        )

        self.set_block_state(state, block_state)

        return components, state
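The decode step above reverses the latent normalization applied at encode time before calling `vae.decode`. A small sketch of that inverse mapping (the `scaling_factor`/`shift_factor` values are examples taken from a Flux-style `AutoencoderKL` config, not from this diff):

```py
import torch

# Encode side (see encoders.py): latents = (z - shift_factor) * scaling_factor
# Decode side (above):           z = latents / scaling_factor + shift_factor
scaling_factor, shift_factor = 0.3611, 0.1159
z = torch.randn(1, 16, 128, 128)
latents = (z - shift_factor) * scaling_factor
recovered = latents / scaling_factor + shift_factor
assert torch.allclose(recovered, z, atol=1e-6)
```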
src/diffusers/modular_pipelines/z_image/denoise.py (new file, 310 lines)
@@ -0,0 +1,310 @@
|
|||||||
|
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...configuration_utils import FrozenDict
|
||||||
|
from ...guiders import ClassifierFreeGuidance
|
||||||
|
from ...models import ZImageTransformer2DModel
|
||||||
|
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||||
|
from ...utils import logging
|
||||||
|
from ..modular_pipeline import (
|
||||||
|
BlockState,
|
||||||
|
LoopSequentialPipelineBlocks,
|
||||||
|
ModularPipelineBlocks,
|
||||||
|
PipelineState,
|
||||||
|
)
|
||||||
|
from ..modular_pipeline_utils import ComponentSpec, InputParam
|
||||||
|
from .modular_pipeline import ZImageModularPipeline
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||||
|
model_name = "z-image"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"step within the denoising loop that prepares the latent input for the denoiser. "
|
||||||
|
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||||
|
"object (e.g. `ZImageDenoiseLoopWrapper`)"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam(
|
||||||
|
"latents",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"dtype",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.dtype,
|
||||||
|
description="The dtype of the model inputs. Can be generated in input step.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||||
|
latents = block_state.latents.unsqueeze(2).to(
|
||||||
|
block_state.dtype
|
||||||
|
) # [batch_size, num_channels, 1, height, width]
|
||||||
|
block_state.latent_model_input = list(latents.unbind(dim=0)) # list of [num_channels, 1, height, width]
|
||||||
|
|
||||||
|
timestep = t.expand(latents.shape[0]).to(block_state.dtype)
|
||||||
|
timestep = (1000 - timestep) / 1000
|
||||||
|
block_state.timestep = timestep
|
||||||
|
return components, block_state
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageLoopDenoiser(ModularPipelineBlocks):
|
||||||
|
model_name = "z-image"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
guider_input_fields: Dict[str, Any] = {"cap_feats": ("prompt_embeds", "negative_prompt_embeds")},
|
||||||
|
):
|
||||||
|
"""Initialize a denoiser block that calls the denoiser model. This block is used in Z-Image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
guider_input_fields: A dictionary that maps each argument expected by the denoiser model
|
||||||
|
(for example, "encoder_hidden_states") to data stored on 'block_state'. The value can be either:
|
||||||
|
|
||||||
|
- A tuple of strings. For instance, {"encoder_hidden_states": ("prompt_embeds",
|
||||||
|
"negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and
|
||||||
|
`block_state.negative_prompt_embeds` and pass them as the conditional and unconditional batches of
|
||||||
|
'encoder_hidden_states'.
|
||||||
|
- A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward
|
||||||
|
`block_state.image_embeds` for both conditional and unconditional batches.
|
||||||
|
"""
|
||||||
|
if not isinstance(guider_input_fields, dict):
|
||||||
|
raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
|
||||||
|
self._guider_input_fields = guider_input_fields
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec(
|
||||||
|
"guider",
|
||||||
|
ClassifierFreeGuidance,
|
||||||
|
config=FrozenDict({"guidance_scale": 5.0, "enabled": False}),
|
||||||
|
default_creation_method="from_config",
|
||||||
|
),
|
||||||
|
ComponentSpec("transformer", ZImageTransformer2DModel),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Step within the denoising loop that denoise the latents with guidance. "
|
||||||
|
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||||
|
"object (e.g. `ZImageDenoiseLoopWrapper`)"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[Tuple[str, Any]]:
|
||||||
|
inputs = [
|
||||||
|
InputParam(
|
||||||
|
"num_inference_steps",
|
||||||
|
required=True,
|
||||||
|
type_hint=int,
|
||||||
|
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
guider_input_names = []
|
||||||
|
uncond_guider_input_names = []
|
||||||
|
for value in self._guider_input_fields.values():
|
||||||
|
if isinstance(value, tuple):
|
||||||
|
guider_input_names.append(value[0])
|
||||||
|
uncond_guider_input_names.append(value[1])
|
||||||
|
else:
|
||||||
|
guider_input_names.append(value)
|
||||||
|
|
||||||
|
for name in guider_input_names:
|
||||||
|
inputs.append(InputParam(name=name, required=True))
|
||||||
|
for name in uncond_guider_input_names:
|
||||||
|
inputs.append(InputParam(name=name))
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
|
||||||
|
) -> PipelineState:
|
||||||
|
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||||
|
|
||||||
|
# The guider splits model inputs into separate batches for conditional/unconditional predictions.
|
||||||
|
# For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
|
||||||
|
# you will get a guider_state with two batches:
|
||||||
|
# guider_state = [
|
||||||
|
# {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
|
||||||
|
# {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
|
||||||
|
# ]
|
||||||
|
# Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
|
||||||
|
guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
|
||||||
|
|
||||||
|
# run the denoiser for each guidance batch
|
||||||
|
for guider_state_batch in guider_state:
|
||||||
|
components.guider.prepare_models(components.transformer)
|
||||||
|
cond_kwargs = guider_state_batch.as_dict()
|
||||||
|
|
||||||
|
def _convert_dtype(v, dtype):
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
return v.to(dtype)
|
||||||
|
elif isinstance(v, list):
|
||||||
|
return [_convert_dtype(t, dtype) for t in v]
|
||||||
|
return v
|
||||||
|
|
||||||
|
cond_kwargs = {
|
||||||
|
k: _convert_dtype(v, block_state.dtype)
|
||||||
|
for k, v in cond_kwargs.items()
|
||||||
|
if k in self._guider_input_fields.keys()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Predict the noise residual
|
||||||
|
# store the noise_pred in guider_state_batch so that we can apply guidance across all batches
|
||||||
|
model_out_list = components.transformer(
|
||||||
|
x=block_state.latent_model_input,
|
||||||
|
t=block_state.timestep,
|
||||||
|
return_dict=False,
|
||||||
|
**cond_kwargs,
|
||||||
|
)[0]
|
||||||
|
noise_pred = torch.stack(model_out_list, dim=0).squeeze(2)
|
||||||
|
guider_state_batch.noise_pred = -noise_pred
|
||||||
|
components.guider.cleanup_models(components.transformer)
|
||||||
|
|
||||||
|
# Perform guidance
|
||||||
|
block_state.noise_pred = components.guider(guider_state)[0]
|
||||||
|
|
||||||
|
return components, block_state
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageLoopAfterDenoiser(ModularPipelineBlocks):
|
||||||
|
model_name = "z-image"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"step within the denoising loop that update the latents. "
|
||||||
|
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||||
|
"object (e.g. `ZImageDenoiseLoopWrapper`)"
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||||
|
# Perform scheduler step using the predicted output
|
||||||
|
latents_dtype = block_state.latents.dtype
|
||||||
|
block_state.latents = components.scheduler.step(
|
||||||
|
block_state.noise_pred.float(),
|
||||||
|
t,
|
||||||
|
block_state.latents.float(),
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
if block_state.latents.dtype != latents_dtype:
|
||||||
|
block_state.latents = block_state.latents.to(latents_dtype)
|
||||||
|
|
||||||
|
return components, block_state
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||||
|
model_name = "z-image"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Pipeline block that iteratively denoise the latents over `timesteps`. "
|
||||||
|
"The specific steps with each iteration can be customized with `sub_blocks` attributes"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loop_expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loop_inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam(
|
||||||
|
"timesteps",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"num_inference_steps",
|
||||||
|
required=True,
|
||||||
|
type_hint=int,
|
||||||
|
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
|
block_state.num_warmup_steps = max(
|
||||||
|
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
|
||||||
|
for i, t in enumerate(block_state.timesteps):
|
||||||
|
components, block_state = self.loop_step(components, block_state, i=i, t=t)
|
||||||
|
if i == len(block_state.timesteps) - 1 or (
|
||||||
|
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
|
||||||
|
):
|
||||||
|
progress_bar.update()
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageDenoiseStep(ZImageDenoiseLoopWrapper):
|
||||||
|
block_classes = [
|
||||||
|
ZImageLoopBeforeDenoiser,
|
||||||
|
ZImageLoopDenoiser(
|
||||||
|
guider_input_fields={
|
||||||
|
"cap_feats": ("prompt_embeds", "negative_prompt_embeds"),
|
||||||
|
}
|
||||||
|
),
|
||||||
|
ZImageLoopAfterDenoiser,
|
||||||
|
]
|
||||||
|
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Denoise step that iteratively denoise the latents. \n"
|
||||||
|
"Its loop logic is defined in `ZImageDenoiseLoopWrapper.__call__` method \n"
|
||||||
|
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||||
|
" - `ZImageLoopBeforeDenoiser`\n"
|
||||||
|
" - `ZImageLoopDenoiser`\n"
|
||||||
|
" - `ZImageLoopAfterDenoiser`\n"
|
||||||
|
"This block supports text-to-image and image-to-image tasks for Z-Image."
|
||||||
|
)
|
||||||
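`ZImageLoopDenoiser` above runs the transformer once per guidance batch and leaves the combination of predictions to the `guider` component. For reference, the standard classifier-free guidance rule that a CFG guider applies when it is enabled (written here as a plain-tensor sketch, independent of the guider API):

```py
import torch

def cfg_combine(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # Standard CFG: push the unconditional prediction toward the conditional one.
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)

pred_cond = torch.randn(1, 16, 128, 128)
pred_uncond = torch.randn(1, 16, 128, 128)
noise_pred = cfg_combine(pred_cond, pred_uncond, guidance_scale=5.0)  # 5.0 matches the guider's default
```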
src/diffusers/modular_pipelines/z_image/encoders.py (new file, 344 lines)
@@ -0,0 +1,344 @@
|
|||||||
|
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
from transformers import Qwen2Tokenizer, Qwen3Model
|
||||||
|
|
||||||
|
from ...configuration_utils import FrozenDict
|
||||||
|
from ...guiders import ClassifierFreeGuidance
|
||||||
|
from ...image_processor import VaeImageProcessor
|
||||||
|
from ...models import AutoencoderKL
|
||||||
|
from ...utils import is_ftfy_available, logging
|
||||||
|
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||||
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
|
from .modular_pipeline import ZImageModularPipeline
|
||||||
|
|
||||||
|
|
||||||
|
if is_ftfy_available():
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
def get_qwen_prompt_embeds(
|
||||||
|
text_encoder: Qwen3Model,
|
||||||
|
tokenizer: Qwen2Tokenizer,
|
||||||
|
prompt: Union[str, List[str]],
|
||||||
|
device: torch.device,
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
|
||||||
|
for i, prompt_item in enumerate(prompt):
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": prompt_item},
|
||||||
|
]
|
||||||
|
prompt_item = tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
enable_thinking=True,
|
||||||
|
)
|
||||||
|
prompt[i] = prompt_item
|
||||||
|
|
||||||
|
text_inputs = tokenizer(
|
||||||
|
prompt,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=max_sequence_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
text_input_ids = text_inputs.input_ids.to(device)
|
||||||
|
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||||
|
|
||||||
|
prompt_embeds = text_encoder(
|
||||||
|
input_ids=text_input_ids,
|
||||||
|
attention_mask=prompt_masks,
|
||||||
|
output_hidden_states=True,
|
||||||
|
).hidden_states[-2]
|
||||||
|
|
||||||
|
prompt_embeds_list = []
|
||||||
|
|
||||||
|
for i in range(len(prompt_embeds)):
|
||||||
|
prompt_embeds_list.append(prompt_embeds[i][prompt_masks[i]])
|
||||||
|
|
||||||
|
return prompt_embeds_list
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||||
|
def retrieve_latents(
|
||||||
|
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||||
|
):
|
||||||
|
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||||
|
return encoder_output.latent_dist.sample(generator)
|
||||||
|
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||||
|
return encoder_output.latent_dist.mode()
|
||||||
|
elif hasattr(encoder_output, "latents"):
|
||||||
|
return encoder_output.latents
|
||||||
|
else:
|
||||||
|
raise AttributeError("Could not access latents of provided encoder_output")
|
||||||
|
|
||||||
|
|
||||||
|
def encode_vae_image(
|
||||||
|
image_tensor: torch.Tensor,
|
||||||
|
vae: AutoencoderKL,
|
||||||
|
generator: torch.Generator,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
latent_channels: int = 16,
|
||||||
|
):
|
||||||
|
if not isinstance(image_tensor, torch.Tensor):
|
||||||
|
raise ValueError(f"Expected image_tensor to be a tensor, got {type(image_tensor)}.")
|
||||||
|
|
||||||
|
if isinstance(generator, list) and len(generator) != image_tensor.shape[0]:
|
||||||
|
raise ValueError(
|
||||||
|
f"You have passed a list of generators of length {len(generator)}, but it is not same as number of images {image_tensor.shape[0]}."
|
||||||
|
)
|
||||||
|
|
||||||
|
image_tensor = image_tensor.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
if isinstance(generator, list):
|
||||||
|
image_latents = [
|
||||||
|
retrieve_latents(vae.encode(image_tensor[i : i + 1]), generator=generator[i])
|
||||||
|
for i in range(image_tensor.shape[0])
|
||||||
|
]
|
||||||
|
image_latents = torch.cat(image_latents, dim=0)
|
||||||
|
else:
|
||||||
|
image_latents = retrieve_latents(vae.encode(image_tensor), generator=generator)
|
||||||
|
|
||||||
|
image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
|
||||||
|
|
||||||
|
return image_latents
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageTextEncoderStep(ModularPipelineBlocks):
|
||||||
|
model_name = "z-image"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Text Encoder step that generate text_embeddings to guide the video generation"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("text_encoder", Qwen3Model),
|
||||||
|
ComponentSpec("tokenizer", Qwen2Tokenizer),
|
||||||
|
ComponentSpec(
|
||||||
|
"guider",
|
||||||
|
ClassifierFreeGuidance,
|
||||||
|
config=FrozenDict({"guidance_scale": 5.0, "enabled": False}),
|
||||||
|
default_creation_method="from_config",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("prompt"),
|
||||||
|
InputParam("negative_prompt"),
|
||||||
|
InputParam("max_sequence_length", default=512),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"prompt_embeds",
|
||||||
|
type_hint=List[torch.Tensor],
|
||||||
|
kwargs_type="denoiser_input_fields",
|
||||||
|
description="text embeddings used to guide the image generation",
|
||||||
|
),
|
||||||
|
OutputParam(
|
||||||
|
"negative_prompt_embeds",
|
||||||
|
type_hint=List[torch.Tensor],
|
||||||
|
kwargs_type="denoiser_input_fields",
|
||||||
|
description="negative text embeddings used to guide the image generation",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_inputs(block_state):
|
||||||
|
if block_state.prompt is not None and (
|
||||||
|
not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
|
||||||
|
):
|
||||||
|
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def encode_prompt(
|
||||||
|
components,
|
||||||
|
prompt: str,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
prepare_unconditional_embeds: bool = True,
|
||||||
|
negative_prompt: Optional[str] = None,
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Encodes the prompt into text encoder hidden states.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str` or `List[str]`, *optional*):
|
||||||
|
prompt to be encoded
|
||||||
|
device: (`torch.device`):
|
||||||
|
torch device
|
||||||
|
prepare_unconditional_embeds (`bool`):
|
||||||
|
whether to use prepare unconditional embeddings or not
|
||||||
|
negative_prompt (`str` or `List[str]`, *optional*):
|
||||||
|
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||||
|
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||||
|
less than `1`).
|
||||||
|
max_sequence_length (`int`, defaults to `512`):
|
||||||
|
The maximum number of text tokens to be used for the generation process.
|
||||||
|
"""
|
||||||
|
device = device or components._execution_device
|
||||||
|
if not isinstance(prompt, list):
|
||||||
|
prompt = [prompt]
|
||||||
|
batch_size = len(prompt)
|
||||||
|
|
||||||
|
prompt_embeds = get_qwen_prompt_embeds(
|
||||||
|
text_encoder=components.text_encoder,
|
||||||
|
tokenizer=components.tokenizer,
|
||||||
|
prompt=prompt,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
negative_prompt_embeds = None
|
||||||
|
if prepare_unconditional_embeds:
|
||||||
|
negative_prompt = negative_prompt or ""
|
||||||
|
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||||
|
|
||||||
|
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||||
|
raise TypeError(
|
||||||
|
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||||
|
f" {type(prompt)}."
|
||||||
|
)
|
||||||
|
elif batch_size != len(negative_prompt):
|
||||||
|
raise ValueError(
|
||||||
|
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||||
|
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||||
|
" the batch size of `prompt`."
|
||||||
|
)
|
||||||
|
|
||||||
|
negative_prompt_embeds = get_qwen_prompt_embeds(
|
||||||
|
text_encoder=components.text_encoder,
|
||||||
|
tokenizer=components.tokenizer,
|
||||||
|
prompt=negative_prompt,
|
||||||
|
max_sequence_length=max_sequence_length,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
return prompt_embeds, negative_prompt_embeds
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
# Get inputs and intermediates
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
self.check_inputs(block_state)
|
||||||
|
|
||||||
|
block_state.device = components._execution_device
|
||||||
|
|
||||||
|
# Encode input prompt
|
||||||
|
(
|
||||||
|
block_state.prompt_embeds,
|
||||||
|
block_state.negative_prompt_embeds,
|
||||||
|
) = self.encode_prompt(
|
||||||
|
components=components,
|
||||||
|
prompt=block_state.prompt,
|
||||||
|
device=block_state.device,
|
||||||
|
prepare_unconditional_embeds=components.requires_unconditional_embeds,
|
||||||
|
negative_prompt=block_state.negative_prompt,
|
||||||
|
max_sequence_length=block_state.max_sequence_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add outputs
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class ZImageVaeImageEncoderStep(ModularPipelineBlocks):
    model_name = "z-image"

    @property
    def description(self) -> str:
        return "VAE Image Encoder step that generates image_latents from the input image to guide the image generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKL),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict({"vae_scale_factor": 8 * 2}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image", type_hint=PIL.Image.Image, required=True),
            InputParam("height"),
            InputParam("width"),
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="latent representation of the input image used to condition the generation",
            ),
        ]

    @staticmethod
    def check_inputs(components, block_state):
        if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
            block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
        ):
            raise ValueError(
                f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
            )

    def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(components, block_state)

        image = block_state.image

        device = components._execution_device
        dtype = torch.float32
        vae_dtype = components.vae.dtype

        image_tensor = components.image_processor.preprocess(
            image, height=block_state.height, width=block_state.width
        ).to(device=device, dtype=dtype)

        block_state.image_latents = encode_vae_image(
            image_tensor=image_tensor,
            vae=components.vae,
            generator=block_state.generator,
            device=device,
            dtype=vae_dtype,
            latent_channels=components.num_channels_latents,
        )

        self.set_block_state(state, block_state)
        return components, state

src/diffusers/modular_pipelines/z_image/modular_blocks.py (new file, 191 lines)
@@ -0,0 +1,191 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
    ZImageAdditionalInputsStep,
    ZImagePrepareLatentsStep,
    ZImagePrepareLatentswithImageStep,
    ZImageSetTimestepsStep,
    ZImageSetTimestepsWithStrengthStep,
    ZImageTextInputStep,
)
from .decoders import ZImageVaeDecoderStep
from .denoise import (
    ZImageDenoiseStep,
)
from .encoders import (
    ZImageTextEncoderStep,
    ZImageVaeImageEncoderStep,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# z-image
# text2image
class ZImageCoreDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
        ZImageTextInputStep,
        ZImagePrepareLatentsStep,
        ZImageSetTimestepsStep,
        ZImageDenoiseStep,
    ]
    block_names = ["input", "prepare_latents", "set_timesteps", "denoise"]

    @property
    def description(self):
        return (
            "Denoise block that takes encoded conditions and runs the denoising process.\n"
            + "This is a sequential pipeline of blocks:\n"
            + " - `ZImageTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `ZImagePrepareLatentsStep` is used to prepare the latents\n"
            + " - `ZImageSetTimestepsStep` is used to set the timesteps\n"
            + " - `ZImageDenoiseStep` is used to denoise the latents\n"
        )


# z-image: image2image
## denoise
class ZImageImage2ImageCoreDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
        ZImageTextInputStep,
        ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"]),
        ZImagePrepareLatentsStep,
        ZImageSetTimestepsStep,
        ZImageSetTimestepsWithStrengthStep,
        ZImagePrepareLatentswithImageStep,
        ZImageDenoiseStep,
    ]
    block_names = [
        "input",
        "additional_inputs",
        "prepare_latents",
        "set_timesteps",
        "set_timesteps_with_strength",
        "prepare_latents_with_image",
        "denoise",
    ]

    @property
    def description(self):
        return (
            "Denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
            + "This is a sequential pipeline of blocks:\n"
            + " - `ZImageTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `ZImageAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
            + " - `ZImagePrepareLatentsStep` is used to prepare the latents\n"
            + " - `ZImageSetTimestepsStep` is used to set the timesteps\n"
            + " - `ZImageSetTimestepsWithStrengthStep` is used to set the timesteps with strength\n"
            + " - `ZImagePrepareLatentswithImageStep` is used to prepare the latents with the image\n"
            + " - `ZImageDenoiseStep` is used to denoise the latents\n"
        )


## auto blocks
class ZImageAutoDenoiseStep(AutoPipelineBlocks):
    block_classes = [
        ZImageImage2ImageCoreDenoiseStep,
        ZImageCoreDenoiseStep,
    ]
    block_names = ["image2image", "text2image"]
    block_trigger_inputs = ["image_latents", None]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents. "
            "This is an auto pipeline block that works for text2image and image2image tasks.\n"
            + " - `ZImageCoreDenoiseStep` (text2image) is used for text2image tasks.\n"
            + " - `ZImageImage2ImageCoreDenoiseStep` (image2image) is used for image2image tasks.\n"
            + " - if `image_latents` is provided, `ZImageImage2ImageCoreDenoiseStep` will be used.\n"
            + " - if `image_latents` is not provided, `ZImageCoreDenoiseStep` will be used.\n"
        )


class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks):
    block_classes = [ZImageVaeImageEncoderStep]
    block_names = ["vae_image_encoder"]
    block_trigger_inputs = ["image"]

    @property
    def description(self) -> str:
        return (
            "VAE Image Encoder step that encodes the image into image latents.\n"
            + "This is an auto pipeline block that works for image2image tasks.\n"
            + " - `ZImageVaeImageEncoderStep` is used when `image` is provided.\n"
            + " - if `image` is not provided, this step will be skipped."
        )


class ZImageAutoBlocks(SequentialPipelineBlocks):
    block_classes = [
        ZImageTextEncoderStep,
        ZImageAutoVaeImageEncoderStep,
        ZImageAutoDenoiseStep,
        ZImageVaeDecoderStep,
    ]
    block_names = ["text_encoder", "vae_image_encoder", "denoise", "decode"]

    @property
    def description(self) -> str:
        return (
            "Auto modular pipeline for text-to-image and image-to-image generation using Z-Image.\n"
            + " - for text-to-image generation, all you need to provide is `prompt`\n"
            + " - for image-to-image generation, you need to provide `image`\n"
            + " - if `image` is not provided, the image encoding step will be skipped."
        )


# presets
TEXT2IMAGE_BLOCKS = InsertableDict(
    [
        ("text_encoder", ZImageTextEncoderStep),
        ("input", ZImageTextInputStep),
        ("prepare_latents", ZImagePrepareLatentsStep),
        ("set_timesteps", ZImageSetTimestepsStep),
        ("denoise", ZImageDenoiseStep),
        ("decode", ZImageVaeDecoderStep),
    ]
)

IMAGE2IMAGE_BLOCKS = InsertableDict(
    [
        ("text_encoder", ZImageTextEncoderStep),
        ("vae_image_encoder", ZImageVaeImageEncoderStep),
        ("input", ZImageTextInputStep),
        ("additional_inputs", ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"])),
        ("prepare_latents", ZImagePrepareLatentsStep),
        ("set_timesteps", ZImageSetTimestepsStep),
        ("set_timesteps_with_strength", ZImageSetTimestepsWithStrengthStep),
        ("prepare_latents_with_image", ZImagePrepareLatentswithImageStep),
        ("denoise", ZImageDenoiseStep),
        ("decode", ZImageVaeDecoderStep),
    ]
)


AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", ZImageTextEncoderStep),
        ("vae_image_encoder", ZImageAutoVaeImageEncoderStep),
        ("denoise", ZImageAutoDenoiseStep),
        ("decode", ZImageVaeDecoderStep),
    ]
)

ALL_BLOCKS = {
    "text2image": TEXT2IMAGE_BLOCKS,
    "image2image": IMAGE2IMAGE_BLOCKS,
    "auto": AUTO_BLOCKS,
}
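Taken together, the presets above are the building blocks of a runnable Z-Image pipeline. Below is a minimal usage sketch; the repo id is hypothetical and the helper methods (`from_blocks_dict`, `init_pipeline`, `load_default_components`, the `output="images"` selector) are assumed from the general modular-pipelines workflow rather than guaranteed by this diff:

```py
import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.z_image.modular_blocks import TEXT2IMAGE_BLOCKS

# Assemble the text2image preset into a sequential pipeline of blocks (assumed helper).
blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)

# Bind the blocks to a model repository and load the components they declare.
pipeline = blocks.init_pipeline("Tongyi-MAI/Z-Image-Turbo")  # hypothetical repo id
pipeline.load_default_components(torch_dtype=torch.bfloat16)  # assumed loader name
pipeline.to("cuda")

# Run the pipeline and pull the decoded images out of the pipeline state.
images = pipeline(prompt="a cat wearing sunglasses", output="images")
images[0].save("z_image_t2i.png")
```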
src/diffusers/modular_pipelines/z_image/modular_pipeline.py (new file, 72 lines)
@@ -0,0 +1,72 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ...loaders import ZImageLoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class ZImageModularPipeline(
    ModularPipeline,
    ZImageLoraLoaderMixin,
):
    """
    A ModularPipeline for Z-Image.

    > [!WARNING]
    > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "ZImageAutoBlocks"

    @property
    def default_height(self):
        return 1024

    @property
    def default_width(self):
        return 1024

    @property
    def vae_scale_factor_spatial(self):
        vae_scale_factor_spatial = 16
        if hasattr(self, "image_processor") and self.image_processor is not None:
            vae_scale_factor_spatial = self.image_processor.config.vae_scale_factor
        return vae_scale_factor_spatial

    @property
    def vae_scale_factor(self):
        vae_scale_factor = 8
        if hasattr(self, "vae") and self.vae is not None:
            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        return vae_scale_factor

    @property
    def num_channels_latents(self):
        num_channels_latents = 16
        if hasattr(self, "transformer") and self.transformer is not None:
            num_channels_latents = self.transformer.config.in_channels
        return num_channels_latents

    @property
    def requires_unconditional_embeds(self):
        requires_unconditional_embeds = False

        if hasattr(self, "guider") and self.guider is not None:
            requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1

        return requires_unconditional_embeds
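The properties above are best-effort lookups: each one returns a sensible default until the corresponding component is registered on the pipeline. A tiny self-contained sketch of the same fallback logic, using stand-in objects instead of real components:

```py
from types import SimpleNamespace

# Stand-in for a loaded VAE; only the attribute read by `vae_scale_factor` is modelled.
vae = SimpleNamespace(config=SimpleNamespace(block_out_channels=[128, 256, 512, 512]))
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
print(vae_scale_factor)  # 8 - the same value used as the default when no VAE is loaded

# Stand-in for a guider; negative prompts are only encoded when classifier-free
# guidance is enabled and expects more than one condition.
guider = SimpleNamespace(_enabled=True, num_conditions=2)
print(guider._enabled and guider.num_conditions > 1)  # True
```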
@@ -109,7 +109,7 @@ LIBRARIES = []
 for library in LOADABLE_CLASSES:
     LIBRARIES.append(library)

-SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
+SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device(), "cpu"]

 logger = logging.get_logger(__name__)

@@ -462,8 +462,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         pipeline_is_sequentially_offloaded = any(
             module_is_sequentially_offloaded(module) for _, module in self.components.items()
         )
-
-        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
+        is_pipeline_device_mapped = self._is_pipeline_device_mapped()
         if is_pipeline_device_mapped:
             raise ValueError(
                 "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
@@ -1164,7 +1163,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         """
         self._maybe_raise_error_if_group_offload_active(raise_error=True)

-        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
+        is_pipeline_device_mapped = self._is_pipeline_device_mapped()
         if is_pipeline_device_mapped:
             raise ValueError(
                 "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
@@ -1286,7 +1285,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
         self.remove_all_hooks()

-        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
+        is_pipeline_device_mapped = self._is_pipeline_device_mapped()
         if is_pipeline_device_mapped:
             raise ValueError(
                 "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
@@ -2171,6 +2170,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 return True
         return False

+    def _is_pipeline_device_mapped(self):
+        # We support passing `device_map="cuda"`, for example. This is helpful in case
+        # users want to pass `device_map="cpu"` when initializing a pipeline. This explicit declaration is desirable
+        # in limited VRAM environments because quantized models often initialize directly on the accelerator.
+        device_map = self.hf_device_map
+        is_device_type_map = False
+        if isinstance(device_map, str):
+            try:
+                torch.device(device_map)
+                is_device_type_map = True
+            except RuntimeError:
+                pass
+
+        return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1
+

 class StableDiffusionMixin:
     r"""
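With this helper, a plain device type passed as `device_map` (for example `"cuda"`, or the newly supported `"cpu"`) is no longer treated as a multi-device map, so `to()`, `enable_model_cpu_offload()`, and `enable_sequential_cpu_offload()` remain usable; only a per-module dict map trips the guards. A standalone sketch of the same predicate:

```py
import torch

def is_pipeline_device_mapped(hf_device_map):
    # Mirrors the `_is_pipeline_device_mapped` logic added above.
    is_device_type_map = False
    if isinstance(hf_device_map, str):
        try:
            torch.device(hf_device_map)  # "cuda", "cpu", "cuda:1", ... all parse fine
            is_device_type_map = True
        except RuntimeError:
            pass
    return not is_device_type_map and isinstance(hf_device_map, dict) and len(hf_device_map) > 1

print(is_pipeline_device_mapped("cuda"))                          # False: a device type, not a map
print(is_pipeline_device_mapped({"unet": 0, "text_encoder": 1}))  # True: per-module placement
```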
@@ -86,42 +86,42 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
     methods the library implements for all schedulers such as loading and saving.

     Args:
-        num_train_timesteps (`int`, defaults to 1000):
+        num_train_timesteps (`int`, defaults to `1000`):
             The number of diffusion steps to train the model.
-        beta_start (`float`, defaults to 0.0001):
+        beta_start (`float`, defaults to `0.0001`):
             The starting `beta` value of inference.
-        beta_end (`float`, defaults to 0.02):
+        beta_end (`float`, defaults to `0.02`):
             The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
+        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
             The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
             `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
-        trained_betas (`np.ndarray`, *optional*):
+        trained_betas (`np.ndarray` or `List[float]`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
-        solver_order (`int`, defaults to 2):
+        solver_order (`int`, defaults to `2`):
             The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
             sampling, and `solver_order=3` for unconditional sampling.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
+        prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
             Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
-            `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
-            Video](https://huggingface.co/papers/2210.02303) paper).
+            `sample` (directly predicts the noisy sample`), `v_prediction` (see section 2.4 of [Imagen
+            Video](https://huggingface.co/papers/2210.02303) paper), or `flow_prediction`.
         thresholding (`bool`, defaults to `False`):
             Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
             as Stable Diffusion.
-        dynamic_thresholding_ratio (`float`, defaults to 0.995):
+        dynamic_thresholding_ratio (`float`, defaults to `0.995`):
             The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
-        sample_max_value (`float`, defaults to 1.0):
+        sample_max_value (`float`, defaults to `1.0`):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
             `algorithm_type="dpmsolver++"`.
-        algorithm_type (`str`, defaults to `dpmsolver++`):
-            Algorithm type for the solver; can be `dpmsolver` or `dpmsolver++` or `sde-dpmsolver++`. The `dpmsolver`
+        algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`):
+            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, or `sde-dpmsolver++`. The `dpmsolver`
             type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the
             `dpmsolver++` type implements the algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095)
             paper. It is recommended to use `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided
             sampling like in Stable Diffusion.
-        solver_type (`str`, defaults to `midpoint`):
+        solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`):
             Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
             sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
-        lower_order_final (`bool`, defaults to `True`):
+        lower_order_final (`bool`, defaults to `False`):
             Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
             stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
         use_karras_sigmas (`bool`, *optional*, defaults to `False`):
@@ -132,15 +132,23 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         use_beta_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
             Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
-        final_sigmas_type (`str`, *optional*, defaults to `"zero"`):
+        use_flow_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
+        flow_shift (`float`, *optional*, defaults to `1.0`):
+            The flow shift parameter for flow-based models.
+        final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`):
             The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
-            sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
+            sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
         lambda_min_clipped (`float`, defaults to `-inf`):
             Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
             cosine (`squaredcos_cap_v2`) noise schedule.
-        variance_type (`str`, *optional*):
-            Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
-            contains the predicted Gaussian variance.
+        variance_type (`"learned"` or `"learned_range"`, *optional*):
+            Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's
+            output contains the predicted Gaussian variance.
+        use_dynamic_shifting (`bool`, defaults to `False`):
+            Whether to use dynamic shifting for the noise schedule.
+        time_shift_type (`"exponential"`, defaults to `"exponential"`):
+            The type of time shifting to apply.
     """

     _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -152,27 +160,27 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         num_train_timesteps: int = 1000,
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
-        beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         solver_order: int = 2,
-        prediction_type: str = "epsilon",
+        prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
         thresholding: bool = False,
         dynamic_thresholding_ratio: float = 0.995,
         sample_max_value: float = 1.0,
-        algorithm_type: str = "dpmsolver++",
-        solver_type: str = "midpoint",
+        algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver++"] = "dpmsolver++",
+        solver_type: Literal["midpoint", "heun"] = "midpoint",
         lower_order_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
         use_flow_sigmas: Optional[bool] = False,
         flow_shift: Optional[float] = 1.0,
-        final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
+        final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
         lambda_min_clipped: float = -float("inf"),
-        variance_type: Optional[str] = None,
+        variance_type: Optional[Literal["learned", "learned_range"]] = None,
         use_dynamic_shifting: bool = False,
-        time_shift_type: str = "exponential",
-    ):
+        time_shift_type: Literal["exponential"] = "exponential",
+    ) -> None:
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
         if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
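The `Literal` annotations make the accepted option strings visible in the signature itself. For example, configuring the single-step solver for a flow-matching model could look roughly like this (the option names come from the signature above; the chosen values are illustrative only):

```py
from diffusers import DPMSolverSinglestepScheduler

scheduler = DPMSolverSinglestepScheduler(
    prediction_type="flow_prediction",  # "epsilon", "sample", "v_prediction", or "flow_prediction"
    algorithm_type="dpmsolver++",       # "dpmsolver", "dpmsolver++", or "sde-dpmsolver++"
    solver_type="midpoint",             # "midpoint" or "heun"
    use_flow_sigmas=True,               # flow-matching sigma schedule
    flow_shift=3.0,                     # illustrative shift value for flow-based models
    final_sigmas_type="zero",           # "zero" or "sigma_min"
)
scheduler.set_timesteps(num_inference_steps=30)
print(scheduler.timesteps[:5])
```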
@@ -242,6 +250,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         Args:
             num_inference_steps (`int`):
                 The number of diffusion steps used when generating samples with a pre-trained model.
+
+        Returns:
+            `List[int]`:
+                The list of solver orders for each timestep.
         """
         steps = num_inference_steps
         order = self.config.solver_order
@@ -276,21 +288,29 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         return orders

     @property
-    def step_index(self):
+    def step_index(self) -> Optional[int]:
         """
         The index counter for current timestep. It will increase 1 after each scheduler step.
+
+        Returns:
+            `int` or `None`:
+                The current step index.
         """
         return self._step_index

     @property
-    def begin_index(self):
+    def begin_index(self) -> Optional[int]:
         """
         The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+
+        Returns:
+            `int` or `None`:
+                The begin index.
         """
         return self._begin_index

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
-    def set_begin_index(self, begin_index: int = 0):
+    def set_begin_index(self, begin_index: int = 0) -> None:
         """
         Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

@@ -302,19 +322,21 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):

     def set_timesteps(
         self,
-        num_inference_steps: int = None,
-        device: Union[str, torch.device] = None,
+        num_inference_steps: Optional[int] = None,
+        device: Optional[Union[str, torch.device]] = None,
         mu: Optional[float] = None,
         timesteps: Optional[List[int]] = None,
-    ):
+    ) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).

         Args:
-            num_inference_steps (`int`):
+            num_inference_steps (`int`, *optional*):
                 The number of diffusion steps used when generating samples with a pre-trained model.
             device (`str` or `torch.device`, *optional*):
                 The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+            mu (`float`, *optional*):
+                The mu parameter for dynamic shifting.
             timesteps (`List[int]`, *optional*):
                 Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                 timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
@@ -453,7 +475,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         return sample

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
-    def _sigma_to_t(self, sigma, log_sigmas):
+    def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
         """
         Convert sigma values to corresponding timestep values through interpolation.

@@ -490,7 +512,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         return t

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
-    def _sigma_to_alpha_sigma_t(self, sigma):
+    def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Convert sigma values to alpha_t and sigma_t values.

@@ -512,7 +534,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         return alpha_t, sigma_t

     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
-    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps) -> torch.Tensor:
+    def _convert_to_karras(self, in_sigmas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
         """
         Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative
         Models](https://huggingface.co/papers/2206.00364).
@@ -637,7 +659,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         self,
         model_output: torch.Tensor,
         *args,
-        sample: torch.Tensor = None,
+        sample: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         """
@@ -733,7 +755,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         self,
         model_output: torch.Tensor,
         *args,
-        sample: torch.Tensor = None,
+        sample: Optional[torch.Tensor] = None,
         noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -797,7 +819,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         self,
         model_output_list: List[torch.Tensor],
         *args,
-        sample: torch.Tensor = None,
+        sample: Optional[torch.Tensor] = None,
         noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -908,7 +930,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         self,
         model_output_list: List[torch.Tensor],
         *args,
-        sample: torch.Tensor = None,
+        sample: Optional[torch.Tensor] = None,
         noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -1030,8 +1052,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         self,
         model_output_list: List[torch.Tensor],
         *args,
-        sample: torch.Tensor = None,
-        order: int = None,
+        sample: Optional[torch.Tensor] = None,
+        order: Optional[int] = None,
         noise: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -1125,7 +1147,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         return step_index

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
-    def _init_step_index(self, timestep):
+    def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
         """
         Initialize the step_index counter for the scheduler.

@@ -1146,7 +1168,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         model_output: torch.Tensor,
         timestep: Union[int, torch.Tensor],
         sample: torch.Tensor,
-        generator=None,
+        generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
     ) -> Union[SchedulerOutput, Tuple]:
         """
@@ -1156,11 +1178,13 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         Args:
             model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
-            timestep (`int`):
+            timestep (`int` or `torch.Tensor`):
                 The current discrete timestep in the diffusion chain.
             sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
-            return_dict (`bool`):
+            generator (`torch.Generator`, *optional*):
+                A random number generator for stochastic sampling.
+            return_dict (`bool`, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

         Returns:
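`generator` only matters for the stochastic `sde-dpmsolver++` variant, where it makes the injected noise reproducible. A bare-bones loop against this interface might look like the sketch below (the latent shape and the denoiser call are placeholders, not part of this diff):

```py
import torch
from diffusers import DPMSolverSinglestepScheduler

scheduler = DPMSolverSinglestepScheduler(algorithm_type="sde-dpmsolver++")
scheduler.set_timesteps(num_inference_steps=20)

sample = torch.randn(1, 4, 64, 64)            # placeholder latent
generator = torch.Generator().manual_seed(0)  # makes the injected noise reproducible

for t in scheduler.timesteps:
    model_output = torch.zeros_like(sample)   # stand-in for a real denoiser call
    sample = scheduler.step(model_output, t, sample, generator=generator, return_dict=False)[0]
```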
@@ -1277,5 +1301,5 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = alpha_t * original_samples + sigma_t * noise
         return noisy_samples

-    def __len__(self):
+    def __len__(self) -> int:
         return self.config.num_train_timesteps
@@ -2,6 +2,36 @@
 from ..utils import DummyObject, requires_backends


+class Flux2AutoBlocks(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class Flux2ModularPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class FluxAutoBlocks(metaclass=DummyObject):
     _backends = ["torch", "transformers"]

@@ -227,6 +257,36 @@ class WanModularPipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])


+class ZImageAutoBlocks(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class ZImageModularPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class AllegroPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]

tests/modular_pipelines/flux2/__init__.py (new file, 0 lines)
tests/modular_pipelines/flux2/test_modular_pipeline_flux2.py (new file, 93 lines)
@@ -0,0 +1,93 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np
import PIL
import pytest

from diffusers.modular_pipelines import (
    Flux2AutoBlocks,
    Flux2ModularPipeline,
)

from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin


class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
    pipeline_class = Flux2ModularPipeline
    pipeline_blocks_class = Flux2AutoBlocks
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"

    params = frozenset(["prompt", "height", "width", "guidance_scale"])
    batch_params = frozenset(["prompt"])

    def get_dummy_inputs(self, seed=0):
        generator = self.get_generator(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
            "max_sequence_length": 8,  # bit of a hack to workaround vocab size mismatch
            "text_encoder_out_layers": (1,),
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 4.0,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }
        return inputs

    def test_float16_inference(self):
        super().test_float16_inference(9e-2)


class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
    pipeline_class = Flux2ModularPipeline
    pipeline_blocks_class = Flux2AutoBlocks
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"

    params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
    batch_params = frozenset(["prompt", "image"])

    def get_dummy_inputs(self, seed=0):
        generator = self.get_generator(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
            "max_sequence_length": 8,  # bit of a hack to workaround vocab size mismatch
            "text_encoder_out_layers": (1,),
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 4.0,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }
        image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
        image = image.cpu().permute(0, 2, 3, 1)[0]
        init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
        inputs["image"] = init_image

        return inputs

    def test_float16_inference(self):
        super().test_float16_inference(9e-2)

    @pytest.mark.skip(reason="batched inference is currently not supported")
    def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
        return
@@ -26,6 +26,7 @@ from diffusers.modular_pipelines import (
     QwenImageModularPipeline,
 )

+from ...testing_utils import torch_device
 from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPipelineTesterMixin


@@ -104,6 +105,16 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul
         inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
         return inputs

+    def test_multi_images_as_input(self):
+        inputs = self.get_dummy_inputs()
+        image = inputs.pop("image")
+        inputs["image"] = [image, image]
+
+        pipe = self.get_pipeline().to(torch_device)
+        _ = pipe(
+            **inputs,
+        )
+
     @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
     def test_num_images_per_prompt(self):
         super().test_num_images_per_prompt()
@@ -117,4 +128,4 @@ class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, Modul
         super().test_inference_batch_single_identical()

     def test_guider_cfg(self):
-        super().test_guider_cfg(1e-3)
+        super().test_guider_cfg(1e-6)
@@ -165,7 +165,6 @@ class ModularPipelineTesterMixin:
         expected_max_diff=1e-4,
     ):
         pipe = self.get_pipeline().to(torch_device)

         inputs = self.get_dummy_inputs()

         # Reset generator in case it is has been used in self.get_dummy_inputs