Mirror of https://github.com/huggingface/diffusers.git
Synced 2025-12-08 13:34:27 +08:00

Compare commits: cache-docs...flux2-modu (11 commits)
Commits:
- 921b959b9a
- 9391a5465d
- b010a8ce0c
- d780d1a42a
- 9264459f88
- 1b91856d0e
- 01e355516b
- 6bf668c4d2
- e6d4612309
- a88a7b4f03
- c8656ed73c
@@ -29,7 +29,7 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate

 [[autodoc]] apply_faster_cache

-## FirstBlockCacheConfig
+### FirstBlockCacheConfig

 [[autodoc]] FirstBlockCacheConfig
@@ -159,7 +159,7 @@ Change the [`~ComponentSpec.default_creation_method`] to `from_pretrained` and u

 ```py
 guider_spec = t2i_pipeline.get_component_spec("guider")
 guider_spec.default_creation_method="from_pretrained"
-guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
+guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
 guider_spec.subfolder="pag_guider"
 pag_guider = guider_spec.load()
 t2i_pipeline.update_components(guider=pag_guider)
@@ -313,14 +313,14 @@ unet_spec

 ComponentSpec(
     name='unet',
     type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
-    repo='RunDiffusion/Juggernaut-XL-v9',
+    pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
     subfolder='unet',
     variant='fp16',
     default_creation_method='from_pretrained'
 )

 # modify to load from a different repository
-unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
+unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"

 # load component with modified spec
 unet = unet_spec.load(torch_dtype=torch.float16)
@@ -66,8 +66,4 @@ config = FasterCacheConfig(
     tensor_format="BFCHW",
 )
 pipeline.transformer.enable_cache(config)
 ```

 ## FirstBlockCache

 [FirstBlock Cache](https://huggingface.co/docs/diffusers/main/en/api/cache#diffusers.FirstBlockCacheConfig) builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler to implement generically for a wide range of models and has been integrated first for experimental purposes.

 ```
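A minimal usage sketch to go with the paragraph above, assuming the same setup as the FasterCache example (the model choice and `threshold` value are illustrative, not recommended defaults):

```py
import torch
from diffusers import CogVideoXPipeline, FirstBlockCacheConfig

pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipeline.to("cuda")

# Skip the later transformer blocks whenever the first block's output barely
# changes between steps; `threshold` sets how much change counts as "barely".
config = FirstBlockCacheConfig(threshold=0.2)
pipeline.transformer.enable_cache(config)
```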
@@ -157,7 +157,7 @@ guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider")

 ```py
 guider_spec = t2i_pipeline.get_component_spec("guider")
 guider_spec.default_creation_method="from_pretrained"
-guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
+guider_spec.pretrained_model_name_or_path="YiYiXu/modular-loader-t2i-guider"
 guider_spec.subfolder="pag_guider"
 pag_guider = guider_spec.load()
 t2i_pipeline.update_components(guider=pag_guider)
@@ -313,14 +313,14 @@ unet_spec

 ComponentSpec(
     name='unet',
     type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
-    repo='RunDiffusion/Juggernaut-XL-v9',
+    pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
     subfolder='unet',
     variant='fp16',
     default_creation_method='from_pretrained'
 )

 # modify to load from a different repository
-unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
+unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"

 # load the component with the modified spec
 unet = unet_spec.load(torch_dtype=torch.float16)
@@ -37,7 +37,7 @@ from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from peft import LoraConfig
-from peft.utils import get_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer

@@ -46,7 +46,12 @@ import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import cast_training_params, compute_snr
-from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
+from diffusers.utils import (
+    check_min_version,
+    convert_state_dict_to_diffusers,
+    convert_unet_state_dict_to_peft,
+    is_wandb_available,
+)
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
@@ -708,6 +713,56 @@ def main():
         num_workers=args.dataloader_num_workers,
     )

+    def save_model_hook(models, weights, output_dir):
+        if accelerator.is_main_process:
+            unet_lora_layers_to_save = None
+
+            for model in models:
+                if isinstance(model, type(unwrap_model(unet))):
+                    unet_lora_layers_to_save = get_peft_model_state_dict(model)
+                else:
+                    raise ValueError(f"Unexpected save model: {model.__class__}")
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+            StableDiffusionPipeline.save_lora_weights(
+                save_directory=output_dir,
+                unet_lora_layers=unet_lora_layers_to_save,
+                safe_serialization=True,
+            )
+
+    def load_model_hook(models, input_dir):
+        unet_ = None
+
+        while len(models) > 0:
+            model = models.pop()
+            if isinstance(model, type(unwrap_model(unet))):
+                unet_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        # returns a tuple of state dictionary and network alphas
+        lora_state_dict, network_alphas = StableDiffusionPipeline.lora_state_dict(input_dir)
+
+        unet_state_dict = {f"{k.replace('unet.', '')}": v for k, v in lora_state_dict.items() if k.startswith("unet.")}
+        unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default")
+
+        if incompatible_keys is not None:
+            # check only for unexpected keys
+            unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+            # throw warning if some unexpected keys are found and continue loading
+            if unexpected_keys:
+                logger.warning(
+                    f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                    f" {unexpected_keys}. "
+                )
+
+        # Make sure the trainable params are in float32
+        if args.mixed_precision in ["fp16"]:
+            cast_training_params([unet_], dtype=torch.float32)
+
     # Scheduler and math around the number of training steps.
     # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
     num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes

@@ -732,6 +787,10 @@ def main():
         unet, optimizer, train_dataloader, lr_scheduler
     )

+    # Register the hooks for efficient saving and loading of LoRA weights
+    accelerator.register_save_state_pre_hook(save_model_hook)
+    accelerator.register_load_state_pre_hook(load_model_hook)
+
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:

@@ -906,17 +965,6 @@ def main():
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)

-                        unwrapped_unet = unwrap_model(unet)
-                        unet_lora_state_dict = convert_state_dict_to_diffusers(
-                            get_peft_model_state_dict(unwrapped_unet)
-                        )
-
-                        StableDiffusionPipeline.save_lora_weights(
-                            save_directory=save_path,
-                            unet_lora_layers=unet_lora_state_dict,
-                            safe_serialization=True,
-                        )
-
                         logger.info(f"Saved state to {save_path}")

             logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
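With these hooks registered, checkpointing round-trips only the LoRA adapter instead of a full pipeline copy. A hedged usage sketch (the checkpoint path is illustrative):

```py
# Saving inside the training loop triggers save_model_hook; resuming with
# load_state triggers load_model_hook, which restores the PEFT adapter weights.
accelerator.save_state("output/checkpoint-500")
accelerator.load_state("output/checkpoint-500")
```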
@@ -399,6 +399,8 @@ except OptionalDependencyNotAvailable:
 else:
     _import_structure["modular_pipelines"].extend(
         [
+            "Flux2AutoBlocks",
+            "Flux2ModularPipeline",
             "FluxAutoBlocks",
             "FluxKontextAutoBlocks",
             "FluxKontextModularPipeline",

@@ -1091,6 +1093,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
 else:
     from .modular_pipelines import (
+        Flux2AutoBlocks,
+        Flux2ModularPipeline,
         FluxAutoBlocks,
         FluxKontextAutoBlocks,
         FluxKontextModularPipeline,
@@ -389,6 +389,14 @@ def is_valid_url(url):
     return False


+def _is_single_file_path_or_url(pretrained_model_name_or_path):
+    # neither a local file nor a valid URL -> not a single-file checkpoint
+    if not os.path.isfile(pretrained_model_name_or_path) and not is_valid_url(pretrained_model_name_or_path):
+        return False
+
+    repo_id, weight_name = _extract_repo_id_and_weights_name(pretrained_model_name_or_path)
+    return bool(repo_id and weight_name)
+
+
 def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
     if not is_valid_url(pretrained_model_name_or_path):
         raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.")

@@ -400,7 +408,6 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
     pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
     match = re.match(pattern, pretrained_model_name_or_path)
     if not match:
         logger.warning("Unable to identify the repo_id and weights_name from the provided URL.")
         return repo_id, weights_name

     repo_id = f"{match.group(1)}/{match.group(2)}"
@@ -41,11 +41,9 @@ class CacheMixin:
     Enable caching techniques on the model.

     Args:
-        config (`Union[PyramidAttentionBroadcastConfig, FasterCacheConfig, FirstBlockCacheConfig]`):
+        config (`Union[PyramidAttentionBroadcastConfig]`):
             The configuration for applying the caching technique. Currently supported caching techniques are:
             - [`~hooks.PyramidAttentionBroadcastConfig`]
             - [`~hooks.FasterCacheConfig`]
             - [`~hooks.FirstBlockCacheConfig`]

     Example:
@@ -69,7 +69,10 @@ class TimestepEmbedder(nn.Module):

     def forward(self, t):
         t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
-        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
+        weight_dtype = self.mlp[0].weight.dtype
+        if weight_dtype.is_floating_point:
+            t_freq = t_freq.to(weight_dtype)
+        t_emb = self.mlp(t_freq)
         return t_emb
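The new `is_floating_point` guard presumably exists so that quantized (integer-dtype) MLP weights do not force an invalid cast of the timestep frequencies. A quick illustration of the predicate:

```py
import torch

print(torch.float16.is_floating_point)  # True  -> t_freq is cast to the weight dtype
print(torch.int8.is_floating_point)     # False -> t_freq keeps its own dtype
```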
@@ -126,6 +129,10 @@ class ZSingleStreamAttnProcessor:
         dtype = query.dtype
         query, key = query.to(dtype), key.to(dtype)

+        # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
+        if attention_mask is not None and attention_mask.ndim == 2:
+            attention_mask = attention_mask[:, None, None, :]
+
         # Compute joint attention
         hidden_states = dispatch_attention_fn(
             query,
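A small shape sketch of what the added mask handling does (shapes illustrative):

```py
import torch

attention_mask = torch.ones(2, 77, dtype=torch.bool)  # [batch, seq_len]
attention_mask = attention_mask[:, None, None, :]     # [batch, 1, 1, seq_len]
print(attention_mask.shape)                           # torch.Size([2, 1, 1, 77])
# SDPA-style kernels broadcast the singleton dims across heads and query positions.
```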
@@ -306,6 +313,10 @@ class RopeEmbedder:
         if self.freqs_cis is None:
             self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
             self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
+        else:
+            # Ensure freqs_cis are on the same device as ids
+            if self.freqs_cis[0].device != device:
+                self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]

         result = []
         for i in range(len(self.axes_dims)):
@@ -317,6 +328,8 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     _supports_gradient_checkpointing = True
     _no_split_modules = ["ZImageTransformerBlock"]
+    _repeated_blocks = ["ZImageTransformerBlock"]
+    _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"]  # precision sensitive layers

     @register_to_config
     def __init__(
@@ -553,8 +566,6 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
         t = t * self.t_scale
         t = self.t_embedder(t)

-        adaln_input = t
-
         (
             x,
             cap_feats,
@@ -572,6 +583,9 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):

         x = torch.cat(x, dim=0)
         x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)

+        # Match t_embedder output dtype to x for layerwise casting compatibility
+        adaln_input = t.type_as(x)
+
         x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
         x = list(x.split(x_item_seqlens, dim=0))
         x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
@@ -52,6 +52,10 @@ else:
         "FluxKontextAutoBlocks",
         "FluxKontextModularPipeline",
     ]
+    _import_structure["flux2"] = [
+        "Flux2AutoBlocks",
+        "Flux2ModularPipeline",
+    ]
     _import_structure["qwenimage"] = [
         "QwenImageAutoBlocks",
         "QwenImageModularPipeline",

@@ -71,6 +75,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
 else:
     from .components_manager import ComponentsManager
     from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
+    from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
     from .modular_pipeline import (
         AutoPipelineBlocks,
         BlockState,
src/diffusers/modular_pipelines/flux2/__init__.py (new file, 123 lines)
@@ -0,0 +1,123 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["encoders"] = [
        "Flux2TextEncoderStep",
        "Flux2RemoteTextEncoderStep",
        "Flux2ProcessImagesInputStep",
        "Flux2VaeEncoderStep",
    ]
    _import_structure["before_denoise"] = [
        "Flux2SetTimestepsStep",
        "Flux2PrepareLatentsStep",
        "Flux2RoPEInputsStep",
        "Flux2PrepareImageLatentsStep",
    ]
    _import_structure["denoise"] = [
        "Flux2LoopDenoiser",
        "Flux2LoopAfterDenoiser",
        "Flux2DenoiseLoopWrapper",
        "Flux2DenoiseStep",
    ]
    _import_structure["decoders"] = ["Flux2DecodeStep"]
    _import_structure["inputs"] = [
        "Flux2TextInputStep",
        "Flux2ImageInputStep",
    ]
    _import_structure["modular_blocks"] = [
        "ALL_BLOCKS",
        "AUTO_BLOCKS",
        "TEXT2IMAGE_BLOCKS",
        "IMAGE_CONDITIONED_BLOCKS",
        "Flux2AutoBeforeDenoiseStep",
        "Flux2AutoBlocks",
        "Flux2AutoDecodeStep",
        "Flux2AutoDenoiseStep",
        "Flux2AutoImageInputStep",
        "Flux2AutoTextEncoderStep",
        "Flux2AutoTextInputStep",
        "Flux2AutoVaeEncoderStep",
        "Flux2BeforeDenoiseStep",
        "Flux2VaeEncoderSequentialStep",
    ]
    _import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .before_denoise import (
            Flux2PrepareImageLatentsStep,
            Flux2PrepareLatentsStep,
            Flux2RoPEInputsStep,
            Flux2SetTimestepsStep,
        )
        from .decoders import Flux2DecodeStep
        from .denoise import (
            Flux2DenoiseLoopWrapper,
            Flux2DenoiseStep,
            Flux2LoopAfterDenoiser,
            Flux2LoopDenoiser,
        )
        from .encoders import (
            Flux2ProcessImagesInputStep,
            Flux2RemoteTextEncoderStep,
            Flux2TextEncoderStep,
            Flux2VaeEncoderStep,
        )
        from .inputs import (
            Flux2ImageInputStep,
            Flux2TextInputStep,
        )
        from .modular_blocks import (
            ALL_BLOCKS,
            AUTO_BLOCKS,
            IMAGE_CONDITIONED_BLOCKS,
            TEXT2IMAGE_BLOCKS,
            Flux2AutoBeforeDenoiseStep,
            Flux2AutoBlocks,
            Flux2AutoDecodeStep,
            Flux2AutoDenoiseStep,
            Flux2AutoImageInputStep,
            Flux2AutoTextEncoderStep,
            Flux2AutoTextInputStep,
            Flux2AutoVaeEncoderStep,
            Flux2BeforeDenoiseStep,
            Flux2VaeEncoderSequentialStep,
        )
        from .modular_pipeline import Flux2ModularPipeline
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
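This file follows the diffusers `_LazyModule` convention: names registered in `_import_structure` are only imported on first attribute access. A small usage sketch of what that buys the caller:

```py
# Importing the package is cheap; the heavy torch/transformers imports happen
# only when an attribute is first resolved through the lazy module.
import diffusers.modular_pipelines.flux2 as flux2

blocks_cls = flux2.Flux2AutoBlocks  # first access triggers the real submodule import
```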
src/diffusers/modular_pipelines/flux2/before_denoise.py (new file, 505 lines)
@@ -0,0 +1,505 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import List, Optional, Union

import numpy as np
import torch

from ...models import Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    """Compute empirical mu for Flux2 timestep scheduling."""
    a1, b1 = 8.73809524e-05, 1.89833333
    a2, b2 = 0.00016927, 0.45666666

    if image_seq_len > 4300:
        mu = a2 * image_seq_len + b2
        return float(mu)

    m_200 = a2 * image_seq_len + b2
    m_10 = a1 * image_seq_len + b1

    a = (m_200 - m_10) / 190.0
    b = m_200 - 200.0 * a
    mu = a * num_steps + b

    return float(mu)
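A quick worked example of the schedule above: below 4300 image tokens, `mu` interpolates linearly in `num_steps` between a 10-step fit (`a1`, `b1`) and a 200-step fit (`a2`, `b2`); above that, only the 200-step fit applies. For a 1024x1024 image (vae_scale_factor 16, so 1024 latent tokens):

```py
# Worked usage of compute_empirical_mu as defined above.
image_seq_len = (1024 // 16 // 2) * (1024 // 16 // 2)  # -> 1024 tokens
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=50)
print(round(mu, 3))  # 1.702
```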
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
class Flux2SetTimestepsStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
            ComponentSpec("transformer", Flux2Transformer2DModel),
        ]

    @property
    def description(self) -> str:
        return "Step that sets the scheduler's timesteps for Flux2 inference using empirical mu calculation"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_inference_steps", default=50),
            InputParam("timesteps"),
            InputParam("sigmas"),
            InputParam("guidance_scale", default=4.0),
            InputParam("latents", type_hint=torch.Tensor),
            InputParam("num_images_per_prompt", default=1),
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam(
                "batch_size",
                required=True,
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
            OutputParam(
                "num_inference_steps",
                type_hint=int,
                description="The number of denoising steps to perform at inference time",
            ),
            OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
        ]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.device = components._execution_device

        scheduler = components.scheduler

        height = block_state.height or components.default_height
        width = block_state.width or components.default_width
        vae_scale_factor = components.vae_scale_factor

        latent_height = 2 * (int(height) // (vae_scale_factor * 2))
        latent_width = 2 * (int(width) // (vae_scale_factor * 2))
        image_seq_len = (latent_height // 2) * (latent_width // 2)

        num_inference_steps = block_state.num_inference_steps
        sigmas = block_state.sigmas
        timesteps = block_state.timesteps

        if timesteps is None and sigmas is None:
            sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
        if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
            sigmas = None

        mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)

        timesteps, num_inference_steps = retrieve_timesteps(
            scheduler,
            num_inference_steps,
            block_state.device,
            timesteps=timesteps,
            sigmas=sigmas,
            mu=mu,
        )
        block_state.timesteps = timesteps
        block_state.num_inference_steps = num_inference_steps

        batch_size = block_state.batch_size * block_state.num_images_per_prompt
        guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
        guidance = guidance.expand(batch_size)
        block_state.guidance = guidance

        components.scheduler.set_begin_index(0)

        self.set_block_state(state, block_state)
        return components, state
class Flux2PrepareLatentsStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def description(self) -> str:
        return "Prepare latents step that prepares the initial noise latents for Flux2 text-to-image generation"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("height", type_hint=int),
            InputParam("width", type_hint=int),
            InputParam("latents", type_hint=Optional[torch.Tensor]),
            InputParam("num_images_per_prompt", type_hint=int, default=1),
            InputParam("generator"),
            InputParam(
                "batch_size",
                required=True,
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
            ),
            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
            ),
            OutputParam("latent_ids", type_hint=torch.Tensor, description="Position IDs for the latents (for RoPE)"),
        ]

    @staticmethod
    def check_inputs(components, block_state):
        vae_scale_factor = components.vae_scale_factor
        if (block_state.height is not None and block_state.height % (vae_scale_factor * 2) != 0) or (
            block_state.width is not None and block_state.width % (vae_scale_factor * 2) != 0
        ):
            logger.warning(
                f"`height` and `width` have to be divisible by {vae_scale_factor * 2} but are {block_state.height} and {block_state.width}."
            )

    @staticmethod
    def _prepare_latent_ids(latents: torch.Tensor):
        """
        Generates 4D position coordinates (T, H, W, L) for latent tensors.

        Args:
            latents: Latent tensor of shape (B, C, H, W)

        Returns:
            Position IDs tensor of shape (B, H*W, 4)
        """
        batch_size, _, height, width = latents.shape

        t = torch.arange(1)
        h = torch.arange(height)
        w = torch.arange(width)
        l = torch.arange(1)

        latent_ids = torch.cartesian_prod(t, h, w, l)
        latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)

        return latent_ids

    @staticmethod
    def _pack_latents(latents):
        """Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)"""
        batch_size, num_channels, height, width = latents.shape
        latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
        return latents

    @staticmethod
    def prepare_latents(
        comp,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        height = 2 * (int(height) // (comp.vae_scale_factor * 2))
        width = 2 * (int(width) // (comp.vae_scale_factor * 2))

        shape = (batch_size, num_channels_latents * 4, height // 2, width // 2)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device=device, dtype=dtype)

        return latents

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        block_state.height = block_state.height or components.default_height
        block_state.width = block_state.width or components.default_width
        block_state.device = components._execution_device
        block_state.num_channels_latents = components.num_channels_latents

        self.check_inputs(components, block_state)
        batch_size = block_state.batch_size * block_state.num_images_per_prompt

        latents = self.prepare_latents(
            components,
            batch_size,
            block_state.num_channels_latents,
            block_state.height,
            block_state.width,
            block_state.dtype,
            block_state.device,
            block_state.generator,
            block_state.latents,
        )

        latent_ids = self._prepare_latent_ids(latents)
        latent_ids = latent_ids.to(block_state.device)

        latents = self._pack_latents(latents)

        block_state.latents = latents
        block_state.latent_ids = latent_ids

        self.set_block_state(state, block_state)
        return components, state
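A small sketch of the packing convention above (shapes illustrative): `_pack_latents` flattens the spatial grid into a token sequence, and `_prepare_latent_ids` records each token's (T, H, W, L) coordinate so tokens can be scattered back into place later.

```py
import torch

latents = torch.randn(1, 128, 64, 64)                       # (B, C, H, W)
packed = latents.reshape(1, 128, 64 * 64).permute(0, 2, 1)  # (B, H*W, C), as in _pack_latents
ids = torch.cartesian_prod(torch.arange(1), torch.arange(64), torch.arange(64), torch.arange(1))
ids = ids.unsqueeze(0)                                      # (B, H*W, 4), as in _prepare_latent_ids
print(packed.shape, ids.shape)  # torch.Size([1, 4096, 128]) torch.Size([1, 4096, 4])
```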
class Flux2RoPEInputsStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return "Step that prepares the 4D RoPE position IDs for Flux2 denoising. Should be placed after text encoder and latent preparation steps."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(name="prompt_embeds", required=True),
            InputParam(name="latent_ids"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                name="txt_ids",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
            ),
            OutputParam(
                name="latent_ids",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
            ),
        ]

    @staticmethod
    def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
        """Prepare 4D position IDs for text tokens."""
        B, L, _ = x.shape
        out_ids = []

        for i in range(B):
            t = torch.arange(1) if t_coord is None else t_coord[i]
            h = torch.arange(1)
            w = torch.arange(1)
            seq_l = torch.arange(L)

            coords = torch.cartesian_prod(t, h, w, seq_l)
            out_ids.append(coords)

        return torch.stack(out_ids)

    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        prompt_embeds = block_state.prompt_embeds
        device = prompt_embeds.device

        block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
        block_state.txt_ids = block_state.txt_ids.to(device)

        self.set_block_state(state, block_state)
        return components, state
class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return "Step that prepares image latents and their position IDs for Flux2 image conditioning."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image_latents", type_hint=List[torch.Tensor]),
            InputParam("batch_size", required=True, type_hint=int),
            InputParam("num_images_per_prompt", default=1, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="Packed image latents for conditioning",
            ),
            OutputParam(
                "image_latent_ids",
                type_hint=torch.Tensor,
                description="Position IDs for image latents",
            ),
        ]

    @staticmethod
    def _prepare_image_ids(image_latents: List[torch.Tensor], scale: int = 10):
        """
        Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.

        Args:
            image_latents: A list of image latent feature tensors of shape (1, C, H, W).
            scale: Factor used to define the time separation between latents.

        Returns:
            Combined coordinate tensor of shape (1, N_total, 4)
        """
        if not isinstance(image_latents, list):
            raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")

        t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
        t_coords = [t.view(-1) for t in t_coords]

        image_latent_ids = []
        for x, t in zip(image_latents, t_coords):
            x = x.squeeze(0)
            _, height, width = x.shape

            x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
            image_latent_ids.append(x_ids)

        image_latent_ids = torch.cat(image_latent_ids, dim=0)
        image_latent_ids = image_latent_ids.unsqueeze(0)

        return image_latent_ids

    @staticmethod
    def _pack_latents(latents):
        """Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)"""
        batch_size, num_channels, height, width = latents.shape
        latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
        return latents

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        image_latents = block_state.image_latents

        if image_latents is None:
            block_state.image_latents = None
            block_state.image_latent_ids = None
        else:
            device = components._execution_device
            batch_size = block_state.batch_size * block_state.num_images_per_prompt

            image_latent_ids = self._prepare_image_ids(image_latents)

            packed_latents = []
            for latent in image_latents:
                packed = self._pack_latents(latent)
                packed = packed.squeeze(0)
                packed_latents.append(packed)

            image_latents = torch.cat(packed_latents, dim=0)
            image_latents = image_latents.unsqueeze(0)

            image_latents = image_latents.repeat(batch_size, 1, 1)
            image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
            image_latent_ids = image_latent_ids.to(device)

            block_state.image_latents = image_latents
            block_state.image_latent_ids = image_latent_ids

        self.set_block_state(state, block_state)
        return components, state
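Note on `_prepare_image_ids`: each reference image gets a distinct T coordinate (`scale`, `2*scale`, ...), which is how the RoPE distinguishes conditioning images from the generated latents (whose IDs use T=0). A tiny check of that arithmetic:

```py
import torch

scale = 10
t_coords = [scale + scale * t for t in torch.arange(0, 3)]
print([int(t) for t in t_coords])  # [10, 20, 30] -- generated latents stay at T=0
```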
src/diffusers/modular_pipelines/flux2/decoders.py (new file, 146 lines)
@@ -0,0 +1,146 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Tuple, Union

import numpy as np
import PIL
import torch

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKLFlux2
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class Flux2DecodeStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLFlux2),
            ComponentSpec(
                "image_processor",
                Flux2ImageProcessor,
                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("output_type", default="pil"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The denoised latents from the denoising step",
            ),
            InputParam(
                "latent_ids",
                required=True,
                type_hint=torch.Tensor,
                description="Position IDs for the latents, used for unpacking",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[str]:
        return [
            OutputParam(
                "images",
                type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
                description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
            )
        ]

    @staticmethod
    def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor:
        """
        Unpack latents using position IDs to scatter tokens into place.

        Args:
            x: Packed latents tensor of shape (B, seq_len, C)
            x_ids: Position IDs tensor of shape (B, seq_len, 4) with (T, H, W, L) coordinates

        Returns:
            Unpacked latents tensor of shape (B, C, H, W)
        """
        x_list = []
        for data, pos in zip(x, x_ids):
            _, ch = data.shape  # noqa: F841
            h_ids = pos[:, 1].to(torch.int64)
            w_ids = pos[:, 2].to(torch.int64)

            h = torch.max(h_ids) + 1
            w = torch.max(w_ids) + 1

            flat_ids = h_ids * w + w_ids

            out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
            out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)

            out = out.view(h, w, ch).permute(2, 0, 1)
            x_list.append(out)

        return torch.stack(x_list, dim=0)

    @staticmethod
    def _unpatchify_latents(latents):
        """Convert patchified latents back to regular format."""
        batch_size, num_channels_latents, height, width = latents.shape
        latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
        latents = latents.permute(0, 1, 4, 2, 5, 3)
        latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
        return latents

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        vae = components.vae

        if block_state.output_type == "latent":
            block_state.images = block_state.latents
        else:
            latents = block_state.latents
            latent_ids = block_state.latent_ids

            latents = self._unpack_latents_with_ids(latents, latent_ids)

            latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
            latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
                latents.device, latents.dtype
            )
            latents = latents * latents_bn_std + latents_bn_mean

            latents = self._unpatchify_latents(latents)

            block_state.images = vae.decode(latents, return_dict=False)[0]
            block_state.images = components.image_processor.postprocess(
                block_state.images, output_type=block_state.output_type
            )

        self.set_block_state(state, block_state)
        return components, state
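A compact toy run of the scatter in `_unpack_latents_with_ids` above (sizes illustrative):

```py
import torch

data = torch.arange(12.0).reshape(6, 2)   # 6 tokens, C = 2
h_ids = torch.tensor([0, 0, 0, 1, 1, 1])  # per-token H coordinate
w_ids = torch.tensor([0, 1, 2, 0, 1, 2])  # per-token W coordinate
flat_ids = h_ids * 3 + w_ids              # flatten (H, W) with W = 3

out = torch.zeros(6, 2)
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, 2), data)  # write token rows into place
print(out.view(2, 3, 2).permute(2, 0, 1).shape)  # torch.Size([2, 2, 3]) -> (C, H, W)
```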
src/diffusers/modular_pipelines/flux2/denoise.py (new file, 252 lines)
@@ -0,0 +1,252 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Tuple

import torch

from ...models import Flux2Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging
from ..modular_pipeline import (
    BlockState,
    LoopSequentialPipelineBlocks,
    ModularPipelineBlocks,
    PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class Flux2LoopDenoiser(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("transformer", Flux2Transformer2DModel)]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoises the latents for Flux2. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `Flux2DenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return [
            InputParam("joint_attention_kwargs"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The latents to denoise. Shape: (B, seq_len, C)",
            ),
            InputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
            ),
            InputParam(
                "image_latent_ids",
                type_hint=torch.Tensor,
                description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
            ),
            InputParam(
                "guidance",
                required=True,
                type_hint=torch.Tensor,
                description="Guidance scale as a tensor",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Text embeddings from Mistral3",
            ),
            InputParam(
                "txt_ids",
                required=True,
                type_hint=torch.Tensor,
                description="4D position IDs for text tokens (T, H, W, L)",
            ),
            InputParam(
                "latent_ids",
                required=True,
                type_hint=torch.Tensor,
                description="4D position IDs for latent tokens (T, H, W, L)",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> PipelineState:
        latents = block_state.latents
        latent_model_input = latents.to(components.transformer.dtype)
        img_ids = block_state.latent_ids

        image_latents = getattr(block_state, "image_latents", None)
        if image_latents is not None:
            latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
            image_latent_ids = block_state.image_latent_ids
            img_ids = torch.cat([img_ids, image_latent_ids], dim=1)

        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        noise_pred = components.transformer(
            hidden_states=latent_model_input,
            timestep=timestep / 1000,
            guidance=block_state.guidance,
            encoder_hidden_states=block_state.prompt_embeds,
            txt_ids=block_state.txt_ids,
            img_ids=img_ids,
            joint_attention_kwargs=block_state.joint_attention_kwargs,
            return_dict=False,
        )[0]

        noise_pred = noise_pred[:, : latents.size(1)]
        block_state.noise_pred = noise_pred

        return components, block_state


class Flux2LoopAfterDenoiser(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that updates the latents after denoising. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `Flux2DenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> List[Tuple[str, Any]]:
        return []

    @property
    def intermediate_inputs(self) -> List[str]:
        return [InputParam("generator")]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        latents_dtype = block_state.latents.dtype
        block_state.latents = components.scheduler.step(
            block_state.noise_pred,
            t,
            block_state.latents,
            return_dict=False,
        )[0]

        if block_state.latents.dtype != latents_dtype:
            if torch.backends.mps.is_available():
                block_state.latents = block_state.latents.to(latents_dtype)

        return components, block_state


class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return (
            "Pipeline block that iteratively denoises the latents over `timesteps`. "
            "The specific steps within each iteration can be customized with `sub_blocks` attribute"
        )

    @property
    def loop_expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
            ComponentSpec("transformer", Flux2Transformer2DModel),
        ]

    @property
    def loop_inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process.",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.num_warmup_steps = max(
            len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
        )

        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
            for i, t in enumerate(block_state.timesteps):
                components, block_state = self.loop_step(components, block_state, i=i, t=t)

                if i == len(block_state.timesteps) - 1 or (
                    (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
                ):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self.set_block_state(state, block_state)
        return components, state


class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
    block_classes = [Flux2LoopDenoiser, Flux2LoopAfterDenoiser]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents for Flux2. \n"
            "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `Flux2LoopDenoiser`\n"
            " - `Flux2LoopAfterDenoiser`\n"
            "This block supports both text-to-image and image-conditioned generation."
        )
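Reviewer note: `Flux2DenoiseStep` is only a recomposition of the wrapper plus two sub-blocks, so custom loop behavior slots in the same way. A hedged sketch under that contract, assuming this file's imports (the logging sub-block is hypothetical, not part of this diff):

```py
class LoggingAfterDenoiser(Flux2LoopAfterDenoiser):
    """Hypothetical sub-block: runs the normal scheduler step, then logs the latent norm."""

    @torch.no_grad()
    def __call__(self, components, block_state, i, t):
        components, block_state = super().__call__(components, block_state, i, t)
        print(f"step {i}: |latents| = {block_state.latents.norm().item():.3f}")
        return components, block_state


class Flux2DenoiseStepWithLogging(Flux2DenoiseLoopWrapper):
    block_classes = [Flux2LoopDenoiser, LoggingAfterDenoiser]
    block_names = ["denoiser", "after_denoiser"]
```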
src/diffusers/modular_pipelines/flux2/encoders.py (new file, 420 lines)
@@ -0,0 +1,420 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple, Union

import torch
from transformers import AutoProcessor, Mistral3ForConditionalGeneration

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKLFlux2
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def format_text_input(prompts: List[str], system_message: str = None):
    """Format prompts for Mistral3 chat template."""
    cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]

    return [
        [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {"role": "user", "content": [{"type": "text", "text": prompt}]},
        ]
        for prompt in cleaned_txt
    ]
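Concretely, `format_text_input` yields one two-message conversation per prompt, with the literal "[IMG]" placeholder stripped from the user text before templating (a toy call; the system message is elided):

```py
messages = format_text_input(["a cat [IMG] on a mat"], system_message="<system prompt>")
# [[{'role': 'system', 'content': [{'type': 'text', 'text': '<system prompt>'}]},
#   {'role': 'user', 'content': [{'type': 'text', 'text': 'a cat  on a mat'}]}]]
```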
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")


class Flux2TextEncoderStep(ModularPipelineBlocks):
    model_name = "flux2"

    # fmt: off
    DEFAULT_SYSTEM_MESSAGE = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
    # fmt: on

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text embeddings using Mistral3 to guide the image generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", Mistral3ForConditionalGeneration),
            ComponentSpec("tokenizer", AutoProcessor),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
            InputParam("max_sequence_length", type_hint=int, default=512, required=False),
            InputParam("text_encoder_out_layers", type_hint=Tuple[int, ...], default=(10, 20, 30), required=False),
            InputParam("joint_attention_kwargs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Text embeddings from Mistral3 used to guide the image generation",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        prompt = block_state.prompt
        prompt_embeds = getattr(block_state, "prompt_embeds", None)

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
                "Please make sure to only forward one of the two."
            )
        elif prompt is not None and not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    @staticmethod
    def _get_mistral_3_prompt_embeds(
        text_encoder: Mistral3ForConditionalGeneration,
        tokenizer: AutoProcessor,
        prompt: Union[str, List[str]],
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        max_sequence_length: int = 512,
        # fmt: off
        system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
        # fmt: on
        hidden_states_layers: Tuple[int, ...] = (10, 20, 30),
    ):
        dtype = text_encoder.dtype if dtype is None else dtype
        device = text_encoder.device if device is None else device

        prompt = [prompt] if isinstance(prompt, str) else prompt

        messages_batch = format_text_input(prompts=prompt, system_message=system_message)

        inputs = tokenizer.apply_chat_template(
            messages_batch,
            add_generation_prompt=False,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_sequence_length,
        )

        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        output = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            use_cache=False,
        )

        # stack the selected hidden states: (batch, num_layers, seq_len, hidden_dim)
        out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
        out = out.to(dtype=dtype, device=device)

        # fold the layer dimension into the feature dimension
        batch_size, num_channels, seq_len, hidden_dim = out.shape
        prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)

        return prompt_embeds

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.device = components._execution_device

        # skip encoding if embeddings were passed in directly
        if block_state.prompt_embeds is not None:
            self.set_block_state(state, block_state)
            return components, state

        prompt = block_state.prompt
        if prompt is None:
            prompt = ""
        prompt = [prompt] if isinstance(prompt, str) else prompt

        block_state.prompt_embeds = self._get_mistral_3_prompt_embeds(
            text_encoder=components.text_encoder,
            tokenizer=components.tokenizer,
            prompt=prompt,
            device=block_state.device,
            max_sequence_length=block_state.max_sequence_length,
            system_message=self.DEFAULT_SYSTEM_MESSAGE,
            hidden_states_layers=block_state.text_encoder_out_layers,
        )

        self.set_block_state(state, block_state)
        return components, state
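
A shape-only sketch of the hidden-state stacking in `_get_mistral_3_prompt_embeds`; the hidden size (5120) and the batch/sequence sizes are assumptions for illustration:

```py
import torch

batch_size, seq_len, hidden_dim = 2, 512, 5120  # assumed sizes
layers = [torch.randn(batch_size, seq_len, hidden_dim) for _ in (10, 20, 30)]

out = torch.stack(layers, dim=1)  # (2, 3, 512, 5120)
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, 3 * hidden_dim)
assert prompt_embeds.shape == (2, 512, 15360)  # the three layers are concatenated per token
```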


class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
    model_name = "flux2"

    REMOTE_URL = "https://remote-text-encoder-flux-2.huggingface.co/predict"

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text embeddings using a remote API endpoint"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Text embeddings from remote API used to guide the image generation",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        prompt = block_state.prompt
        prompt_embeds = getattr(block_state, "prompt_embeds", None)

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
                "Please make sure to only forward one of the two."
            )
        elif prompt is not None and not isinstance(prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        import io

        import requests
        from huggingface_hub import get_token

        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.device = components._execution_device

        if block_state.prompt_embeds is not None:
            self.set_block_state(state, block_state)
            return components, state

        prompt = block_state.prompt
        if prompt is None:
            prompt = ""
        prompt = [prompt] if isinstance(prompt, str) else prompt

        # request embeddings from the hosted text encoder, authenticated with the local HF token
        response = requests.post(
            self.REMOTE_URL,
            json={"prompt": prompt},
            headers={
                "Authorization": f"Bearer {get_token()}",
                "Content-Type": "application/json",
            },
        )
        response.raise_for_status()

        block_state.prompt_embeds = torch.load(io.BytesIO(response.content), weights_only=True)
        block_state.prompt_embeds = block_state.prompt_embeds.to(block_state.device)

        self.set_block_state(state, block_state)
        return components, state


class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return "Image preprocess step for Flux2. Validates and preprocesses reference images."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "image_processor",
                Flux2ImageProcessor,
                config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image"),
            InputParam("height"),
            InputParam("width"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [OutputParam(name="condition_images", type_hint=List[torch.Tensor])]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState):
        block_state = self.get_block_state(state)
        images = block_state.image

        if images is None:
            block_state.condition_images = None
        else:
            if not isinstance(images, list):
                images = [images]

            condition_images = []
            for img in images:
                components.image_processor.check_image_input(img)

                # cap each reference image at a 1024 * 1024 pixel area
                image_width, image_height = img.size
                if image_width * image_height > 1024 * 1024:
                    img = components.image_processor._resize_to_target_area(img, 1024 * 1024)
                    image_width, image_height = img.size

                # snap dimensions down to a multiple of twice the VAE scale factor
                multiple_of = components.vae_scale_factor * 2
                image_width = (image_width // multiple_of) * multiple_of
                image_height = (image_height // multiple_of) * multiple_of
                condition_img = components.image_processor.preprocess(
                    img, height=image_height, width=image_width, resize_mode="crop"
                )
                condition_images.append(condition_img)

            # fall back to the last reference image's dimensions for the output size
            if block_state.height is None:
                block_state.height = image_height
            if block_state.width is None:
                block_state.width = image_width

            block_state.condition_images = condition_images

        self.set_block_state(state, block_state)
        return components, state
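
For intuition, the dimension snapping above works as follows (a sketch with an assumed `vae_scale_factor` of 16, matching the image processor config):

```py
vae_scale_factor = 16  # assumed; components.vae_scale_factor may differ
multiple_of = vae_scale_factor * 2  # 32

image_width, image_height = 1000, 700  # 700000 pixels, under the 1024 * 1024 cap
image_width = (image_width // multiple_of) * multiple_of    # 992
image_height = (image_height // multiple_of) * multiple_of  # 672
```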


class Flux2VaeEncoderStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return "VAE Encoder step that encodes preprocessed images into latent representations for Flux2."

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [ComponentSpec("vae", AutoencoderKLFlux2)]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("condition_images", type_hint=List[torch.Tensor]),
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_latents",
                type_hint=List[torch.Tensor],
                description="List of latent representations for each reference image",
            ),
        ]

    @staticmethod
    def _patchify_latents(latents):
        """Convert latents to patchified format for Flux2: fold each 2x2 spatial patch into the channel dim."""
        batch_size, num_channels_latents, height, width = latents.shape
        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
        latents = latents.permute(0, 1, 3, 5, 2, 4)
        latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
        return latents

    def _encode_vae_image(self, vae: AutoencoderKLFlux2, image: torch.Tensor, generator: torch.Generator):
        """Encode a single image using the Flux2 VAE with batch-norm normalization."""
        if image.ndim != 4:
            raise ValueError(f"Expected image dims 4, got {image.ndim}.")

        image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode="argmax")
        image_latents = self._patchify_latents(image_latents)

        # normalize with the VAE's batch-norm running statistics
        latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
        latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps)
        latents_bn_std = latents_bn_std.to(image_latents.device, image_latents.dtype)
        image_latents = (image_latents - latents_bn_mean) / latents_bn_std

        return image_latents

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        condition_images = block_state.condition_images

        if condition_images is None:
            block_state.image_latents = None
        else:
            device = components._execution_device
            dtype = components.vae.dtype

            image_latents = []
            for image in condition_images:
                image = image.to(device=device, dtype=dtype)
                latent = self._encode_vae_image(
                    vae=components.vae,
                    image=image,
                    generator=block_state.generator,
                )
                image_latents.append(latent)

            block_state.image_latents = image_latents

        self.set_block_state(state, block_state)
        return components, state
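
A shape-only sketch of `_patchify_latents`, with an assumed 32-channel, 64x64 latent:

```py
import torch

latents = torch.randn(1, 32, 64, 64)  # (batch, channels, height, width), assumed sizes
b, c, h, w = latents.shape
patched = latents.view(b, c, h // 2, 2, w // 2, 2).permute(0, 1, 3, 5, 2, 4).reshape(b, c * 4, h // 2, w // 2)
assert patched.shape == (1, 128, 32, 32)  # each 2x2 spatial patch is folded into the channel dim
```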


src/diffusers/modular_pipelines/flux2/inputs.py (new file, 140 lines)
@@ -0,0 +1,140 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import torch

from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import InputParam, OutputParam
from .modular_pipeline import Flux2ModularPipeline


logger = logging.get_logger(__name__)


class Flux2TextInputStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return (
            "Text input processing step that standardizes text embeddings for Flux2 pipeline.\n"
            "This step:\n"
            "  1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
            "  2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_images_per_prompt", default=1),
            InputParam(
                "prompt_embeds",
                required=True,
                kwargs_type="denoiser_input_fields",
                type_hint=torch.Tensor,
                description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "batch_size",
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
            ),
            OutputParam(
                "dtype",
                type_hint=torch.dtype,
                description="Data type of model tensor inputs (determined by `prompt_embeds`)",
            ),
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Text embeddings used to guide the image generation",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.batch_size = block_state.prompt_embeds.shape[0]
        block_state.dtype = block_state.prompt_embeds.dtype

        # duplicate embeddings for each generation per prompt, then flatten into the batch dim
        _, seq_len, _ = block_state.prompt_embeds.shape
        block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
        block_state.prompt_embeds = block_state.prompt_embeds.view(
            block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
        )

        self.set_block_state(state, block_state)
        return components, state
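
A shape-only sketch of the batch expansion above (sizes are assumptions):

```py
import torch

prompt_embeds = torch.randn(2, 512, 15360)  # (batch_size, seq_len, dim), assumed sizes
num_images_per_prompt = 2

expanded = prompt_embeds.repeat(1, num_images_per_prompt, 1).view(2 * num_images_per_prompt, 512, -1)
assert expanded.shape == (4, 512, 15360)  # each prompt's embedding appears twice, back to back
```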


class Flux2ImageInputStep(ModularPipelineBlocks):
    model_name = "flux2"

    @property
    def description(self) -> str:
        return (
            "Image input processing step that prepares image latents for Flux2 conditioning.\n"
            "This step expands image latents to match the batch size."
        )

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("num_images_per_prompt", default=1),
            InputParam("batch_size", required=True, type_hint=int),
            InputParam("image_latents", type_hint=torch.Tensor),
            InputParam("image_latent_ids", type_hint=torch.Tensor),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="Packed image latents expanded to batch size",
            ),
            OutputParam(
                "image_latent_ids",
                type_hint=torch.Tensor,
                description="Image latent position IDs expanded to batch size",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        image_latents = block_state.image_latents
        image_latent_ids = block_state.image_latent_ids
        target_batch_size = block_state.batch_size * block_state.num_images_per_prompt

        if image_latents is not None:
            block_state.image_latents = image_latents.repeat(target_batch_size, 1, 1)

        if image_latent_ids is not None:
            block_state.image_latent_ids = image_latent_ids.repeat(target_batch_size, 1, 1)

        self.set_block_state(state, block_state)
        return components, state


src/diffusers/modular_pipelines/flux2/modular_blocks.py (new file, 237 lines)
@@ -0,0 +1,237 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
    Flux2PrepareImageLatentsStep,
    Flux2PrepareLatentsStep,
    Flux2RoPEInputsStep,
    Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .denoise import Flux2DenoiseStep
from .encoders import (
    Flux2ProcessImagesInputStep,
    Flux2RemoteTextEncoderStep,
    Flux2TextEncoderStep,
    Flux2VaeEncoderStep,
)
from .inputs import (
    Flux2ImageInputStep,
    Flux2TextInputStep,
)


class Flux2AutoTextInputStep(AutoPipelineBlocks):
    block_classes = [Flux2TextInputStep]
    block_names = ["text_input"]
    block_trigger_inputs = [None]

    @property
    def description(self):
        return (
            "Text input step that processes text embeddings and determines batch size.\n"
            " - `Flux2TextInputStep` is always used."
        )


class Flux2AutoImageInputStep(AutoPipelineBlocks):
    block_classes = [Flux2ImageInputStep]
    block_names = ["image_input"]
    block_trigger_inputs = ["image_latents"]

    @property
    def description(self):
        return (
            "Image input step that expands image latents to match batch size.\n"
            " - `Flux2ImageInputStep` is used when `image_latents` is provided.\n"
            " - Skipped when no image conditioning is used."
        )


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


Flux2VaeEncoderBlocks = InsertableDict(
    [
        ("preprocess", Flux2ProcessImagesInputStep()),
        ("encode", Flux2VaeEncoderStep()),
        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
    ]
)


class Flux2VaeEncoderSequentialStep(SequentialPipelineBlocks):
    model_name = "flux2"

    block_classes = Flux2VaeEncoderBlocks.values()
    block_names = Flux2VaeEncoderBlocks.keys()

    @property
    def description(self) -> str:
        return "VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning."


class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
    block_classes = [Flux2VaeEncoderSequentialStep]
    block_names = ["img_conditioning"]
    block_trigger_inputs = ["image"]

    @property
    def description(self):
        return (
            "VAE encoder step that encodes the image inputs into their latent representations.\n"
            "This is an auto pipeline block that works for image conditioning tasks.\n"
            " - `Flux2VaeEncoderSequentialStep` is used when `image` is provided.\n"
            " - If `image` is not provided, step will be skipped."
        )


class Flux2AutoTextEncoderStep(AutoPipelineBlocks):
    block_classes = [Flux2RemoteTextEncoderStep, Flux2TextEncoderStep]
    block_names = ["remote", "local"]
    block_trigger_inputs = ["remote_text_encoder", None]

    @property
    def description(self):
        return (
            "Text encoder step that generates text embeddings to guide the image generation.\n"
            "This is an auto pipeline block that selects between local and remote text encoding.\n"
            " - `Flux2RemoteTextEncoderStep` is used when `remote_text_encoder=True`.\n"
            " - `Flux2TextEncoderStep` is used otherwise (default)."
        )


Flux2BeforeDenoiseBlocks = InsertableDict(
    [
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
    ]
)


class Flux2BeforeDenoiseStep(SequentialPipelineBlocks):
    model_name = "flux2"

    block_classes = Flux2BeforeDenoiseBlocks.values()
    block_names = Flux2BeforeDenoiseBlocks.keys()

    @property
    def description(self):
        return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation."


class Flux2AutoBeforeDenoiseStep(AutoPipelineBlocks):
    model_name = "flux2"
    block_classes = [Flux2BeforeDenoiseStep]
    block_names = ["before_denoise"]
    block_trigger_inputs = [None]

    @property
    def description(self):
        return (
            "Before denoise step that prepares the inputs for the denoise step.\n"
            "This is an auto pipeline block for Flux2.\n"
            " - `Flux2BeforeDenoiseStep` is used for both text-to-image and image-conditioned generation."
        )


class Flux2AutoDenoiseStep(AutoPipelineBlocks):
    block_classes = [Flux2DenoiseStep]
    block_names = ["denoise"]
    block_trigger_inputs = [None]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents.\n"
            "This is an auto pipeline block that works for Flux2 text-to-image and image-conditioned tasks.\n"
            " - `Flux2DenoiseStep` (denoise) for text-to-image and image-conditioned tasks."
        )


class Flux2AutoDecodeStep(AutoPipelineBlocks):
    block_classes = [Flux2DecodeStep]
    block_names = ["decode"]
    block_trigger_inputs = [None]

    @property
    def description(self):
        return "Decode step that decodes the denoised latents into image outputs.\n - `Flux2DecodeStep`"


AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", Flux2AutoTextEncoderStep()),
        ("text_input", Flux2AutoTextInputStep()),
        ("image_encoder", Flux2AutoVaeEncoderStep()),
        ("image_input", Flux2AutoImageInputStep()),
        ("before_denoise", Flux2AutoBeforeDenoiseStep()),
        ("denoise", Flux2AutoDenoiseStep()),
        ("decode", Flux2DecodeStep()),
    ]
)


class Flux2AutoBlocks(SequentialPipelineBlocks):
    model_name = "flux2"

    block_classes = AUTO_BLOCKS.values()
    block_names = AUTO_BLOCKS.keys()

    @property
    def description(self):
        return (
            "Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2.\n"
            "- For text-to-image generation, all you need to provide is `prompt`.\n"
            "- For image-conditioned generation, you need to provide `image` (list of PIL images)."
        )


TEXT2IMAGE_BLOCKS = InsertableDict(
    [
        ("text_encoder", Flux2TextEncoderStep()),
        ("text_input", Flux2TextInputStep()),
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
        ("denoise", Flux2DenoiseStep()),
        ("decode", Flux2DecodeStep()),
    ]
)

IMAGE_CONDITIONED_BLOCKS = InsertableDict(
    [
        ("text_encoder", Flux2TextEncoderStep()),
        ("text_input", Flux2TextInputStep()),
        ("preprocess_images", Flux2ProcessImagesInputStep()),
        ("vae_encoder", Flux2VaeEncoderStep()),
        ("prepare_image_latents", Flux2PrepareImageLatentsStep()),
        ("image_input", Flux2ImageInputStep()),
        ("prepare_latents", Flux2PrepareLatentsStep()),
        ("set_timesteps", Flux2SetTimestepsStep()),
        ("prepare_rope_inputs", Flux2RoPEInputsStep()),
        ("denoise", Flux2DenoiseStep()),
        ("decode", Flux2DecodeStep()),
    ]
)

ALL_BLOCKS = {
    "text2image": TEXT2IMAGE_BLOCKS,
    "image_conditioned": IMAGE_CONDITIONED_BLOCKS,
    "auto": AUTO_BLOCKS,
}
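
A hedged usage sketch of these presets; the repo id and loading method names follow the modular-pipelines API documented elsewhere and are assumptions here:

```py
import torch

from diffusers.modular_pipelines import SequentialPipelineBlocks

blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
pipe = blocks.init_pipeline("black-forest-labs/FLUX.2-dev")  # assumed modular repo id
pipe.load_default_components(torch_dtype=torch.bfloat16)
images = pipe(prompt="a cat holding a sign", output="images")
```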


src/diffusers/modular_pipelines/flux2/modular_pipeline.py (new file, 57 lines)
@@ -0,0 +1,57 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ...loaders import Flux2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
    """
    A ModularPipeline for Flux2.

    > [!WARNING]
    > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "Flux2AutoBlocks"

    @property
    def default_height(self):
        return self.default_sample_size * self.vae_scale_factor

    @property
    def default_width(self):
        return self.default_sample_size * self.vae_scale_factor

    @property
    def default_sample_size(self):
        return 128

    @property
    def vae_scale_factor(self):
        vae_scale_factor = 8
        if getattr(self, "vae", None) is not None:
            vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        return vae_scale_factor

    @property
    def num_channels_latents(self):
        num_channels_latents = 32
        if getattr(self, "transformer", None):
            num_channels_latents = self.transformer.config.in_channels // 4
        return num_channels_latents
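
For illustration, with no `vae` registered the fallback values above give a 1024x1024 default resolution:

```py
default_sample_size = 128
vae_scale_factor = 8  # fallback when no VAE is registered; otherwise 2 ** (len(block_out_channels) - 1)
default_height = default_sample_size * vae_scale_factor  # 1024
```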


@@ -58,6 +58,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
         ("wan", "WanModularPipeline"),
         ("flux", "FluxModularPipeline"),
         ("flux-kontext", "FluxKontextModularPipeline"),
+        ("flux2", "Flux2ModularPipeline"),
         ("qwenimage", "QwenImageModularPipeline"),
         ("qwenimage-edit", "QwenImageEditModularPipeline"),
         ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),

@@ -360,7 +361,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
         collection: Optional[str] = None,
     ) -> "ModularPipeline":
         """
-        create a ModularPipeline, optionally accept modular_repo to load from hub.
+        create a ModularPipeline, optionally accept pretrained_model_name_or_path to load from hub.
         """
         pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
         diffusers_module = importlib.import_module("diffusers")

@@ -1585,7 +1586,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
         for name, config_spec in self._config_specs.items():
             default_configs[name] = config_spec.default
         self.register_to_config(**default_configs)

-        self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)

     @property

@@ -1645,8 +1645,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
             pretrained_model_name_or_path (`str` or `os.PathLike`, optional):
                 Path to a pretrained pipeline configuration. It will first try to load config from
                 `modular_model_index.json`, then fallback to `model_index.json` for compatibility with standard
-                non-modular repositories. If the repo does not contain any pipeline config, it will be set to None
-                during initialization.
+                non-modular repositories. If the pretrained_model_name_or_path does not contain any pipeline config, it
+                will be set to None during initialization.
             trust_remote_code (`bool`, optional):
                 Whether to trust remote code when loading the pipeline, need to be set to True if you want to create
                 pipeline blocks based on the custom code in `pretrained_model_name_or_path`

@@ -1807,7 +1807,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
         library, class_name = None, None

         # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config
-        # e.g. {"repo": "stabilityai/stable-diffusion-2-1",
+        # e.g. {"pretrained_model_name_or_path": "stabilityai/stable-diffusion-2-1",
         #       "type_hint": ("diffusers", "UNet2DConditionModel"),
         #       "subfolder": "unet",
         #       "variant": None,

@@ -2111,8 +2111,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
             **kwargs: additional kwargs to be passed to `from_pretrained()`. Can be:
                 - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
                 - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
-                - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`,
-                  `variant`, `revision`, etc.
+                - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g.
+                  `pretrained_model_name_or_path`, `variant`, `revision`, etc.
+                - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g.
+                  `pretrained_model_name_or_path`, `variant`, `revision`, etc.
         """

         if names is None:

@@ -2378,10 +2380,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
                 - "type_hint": Tuple[str, str]
                     Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
                 - All loading fields defined by `component_spec.loading_fields()`, typically:
-                    - "repo": Optional[str]
-                        The model repository (e.g., "stabilityai/stable-diffusion-xl").
+                    - "pretrained_model_name_or_path": Optional[str]
+                        The model repository (e.g., "stabilityai/stable-diffusion-xl").
                     - "subfolder": Optional[str]
-                        A subfolder within the repo where this component lives.
+                        A subfolder within the pretrained_model_name_or_path where this component lives.
                     - "variant": Optional[str]
                         An optional variant identifier for the model.
                     - "revision": Optional[str]

@@ -2398,11 +2400,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
         Example:
             >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec >>> from diffusers import
             UNet2DConditionModel >>> spec = ComponentSpec(
-            ... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ... repo="path/to/repo", ...
-            subfolder="subfolder", ... variant=None, ... revision=None, ...
-            default_creation_method="from_pretrained",
+            ... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ...
+            pretrained_model_name_or_path="path/to/pretrained_model_name_or_path", ... subfolder="subfolder", ...
+            variant=None, ... revision=None, ... default_creation_method="from_pretrained",
             ... ) >>> ModularPipeline._component_spec_to_dict(spec) {
-            "type_hint": ("diffusers", "UNet2DConditionModel"), "repo": "path/to/repo", "subfolder": "subfolder",
+            "type_hint": ("diffusers", "UNet2DConditionModel"), "pretrained_model_name_or_path": "path/to/repo",
+            "subfolder": "subfolder", "variant": None, "revision": None, "type_hint": ("diffusers",
+            "UNet2DConditionModel"), "pretrained_model_name_or_path": "path/to/repo", "subfolder": "subfolder",
+            "variant": None, "revision": None,
             }
         """

@@ -2432,10 +2436,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
                 - "type_hint": Tuple[str, str]
                     Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
                 - All loading fields defined by `component_spec.loading_fields()`, typically:
-                    - "repo": Optional[str]
+                    - "pretrained_model_name_or_path": Optional[str]
                         The model repository (e.g., "stabilityai/stable-diffusion-xl").
                     - "subfolder": Optional[str]
-                        A subfolder within the repo where this component lives.
+                        A subfolder within the pretrained_model_name_or_path where this component lives.
                     - "variant": Optional[str]
                         An optional variant identifier for the model.
                     - "revision": Optional[str]

@@ -2452,11 +2456,20 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
             ComponentSpec: A reconstructed ComponentSpec object.

         Example:
-            >>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ... "repo":
-            "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant": None, ... "revision": None, ...
-            } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict) ComponentSpec(
-                name="unet", type_hint=UNet2DConditionModel, config=None, repo="stabilityai/stable-diffusion-xl",
-                subfolder="unet", variant=None, revision=None, default_creation_method="from_pretrained"
+            >>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ...
+            "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant":
+            None, ... "revision": None, ... } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
+            ComponentSpec(
+                name="unet", type_hint=UNet2DConditionModel, config=None,
+                pretrained_model_name_or_path="stabilityai/stable-diffusion-xl", subfolder="unet", variant=None,
+                revision=None, default_creation_method="from_pretrained"
+            >>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ...
+            "pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant":
+            None, ... "revision": None, ... } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict)
+            ComponentSpec(
+                name="unet", type_hint=UNet2DConditionModel, config=None,
+                pretrained_model_name_or_path="stabilityai/stable-diffusion-xl", subfolder="unet", variant=None,
+                revision=None, default_creation_method="from_pretrained"
             )
         """
         # make a shallow copy so we can pop() safely

@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
 import torch

 from ..configuration_utils import ConfigMixin, FrozenDict
+from ..loaders.single_file_utils import _is_single_file_path_or_url
 from ..utils import is_torch_available, logging

@@ -80,10 +81,10 @@ class ComponentSpec:
         type_hint: Type of the component (e.g. UNet2DConditionModel)
         description: Optional description of the component
         config: Optional config dict for __init__ creation
-        repo: Optional repo path for from_pretrained creation
-        subfolder: Optional subfolder in repo
-        variant: Optional variant in repo
-        revision: Optional revision in repo
+        pretrained_model_name_or_path: Optional pretrained_model_name_or_path path for from_pretrained creation
+        subfolder: Optional subfolder in pretrained_model_name_or_path
+        variant: Optional variant in pretrained_model_name_or_path
+        revision: Optional revision in pretrained_model_name_or_path
         default_creation_method: Preferred creation method - "from_config" or "from_pretrained"
     """

@@ -91,13 +92,20 @@ class ComponentSpec:
     type_hint: Optional[Type] = None
     description: Optional[str] = None
     config: Optional[FrozenDict] = None
-    # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
-    repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
+    pretrained_model_name_or_path: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
     subfolder: Optional[str] = field(default="", metadata={"loading": True})
     variant: Optional[str] = field(default=None, metadata={"loading": True})
     revision: Optional[str] = field(default=None, metadata={"loading": True})
     default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"

+    # Deprecated
+    repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": False})
+
+    def __post_init__(self):
+        repo_value = self.repo
+        if repo_value is not None and self.pretrained_model_name_or_path is None:
+            object.__setattr__(self, "pretrained_model_name_or_path", repo_value)
+
     def __hash__(self):
         """Make ComponentSpec hashable, using load_id as the hash value."""
         return hash((self.name, self.load_id, self.default_creation_method))
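
A minimal sketch of the backward-compatibility shim added above: a spec created with the deprecated `repo` field still ends up with `pretrained_model_name_or_path` populated via `__post_init__`:

```py
spec = ComponentSpec(name="unet", repo="stabilityai/stable-diffusion-xl-base-1.0")  # deprecated field
assert spec.pretrained_model_name_or_path == "stabilityai/stable-diffusion-xl-base-1.0"
```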

@@ -182,8 +190,8 @@ class ComponentSpec:
     @property
     def load_id(self) -> str:
         """
-        Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty
-        segments).
+        Unique identifier for this spec's pretrained load, composed of
+        pretrained_model_name_or_path|subfolder|variant|revision (no empty segments).
         """
         if self.default_creation_method == "from_config":
             return "null"

@@ -197,12 +205,13 @@ class ComponentSpec:
         Decode a load_id string back into a dictionary of loading fields and values.

         Args:
-            load_id: The load_id string to decode, format: "repo|subfolder|variant|revision"
+            load_id: The load_id string to decode, format: "pretrained_model_name_or_path|subfolder|variant|revision"
                 where None values are represented as "null"

         Returns:
             Dict mapping loading field names to their values. e.g. {
-                "repo": "path/to/repo", "subfolder": "subfolder", "variant": "variant", "revision": "revision"
+                "pretrained_model_name_or_path": "path/to/repo", "subfolder": "subfolder", "variant": "variant",
+                "revision": "revision"
             } If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating
             component not created with `load` method).
         """

@@ -259,34 +268,45 @@ class ComponentSpec:
     # YiYi TODO: add guard for type of model, if it is supported by from_pretrained
     def load(self, **kwargs) -> Any:
         """Load component using from_pretrained."""

-        # select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change
+        # select loading fields from kwargs passed from user: e.g. pretrained_model_name_or_path, subfolder, variant, revision, note the list could change
         passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
         # merge loading field value in the spec with user passed values to create load_kwargs
         load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
-        # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
-        repo = load_kwargs.pop("repo", None)
-        if repo is None:
+
+        pretrained_model_name_or_path = load_kwargs.pop("pretrained_model_name_or_path", None)
+        if pretrained_model_name_or_path is None:
             raise ValueError(
-                "`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)"
+                "`pretrained_model_name_or_path` info is required when using `load` method (you can directly set it in `pretrained_model_name_or_path` field of the ComponentSpec or pass it as an argument)"
             )
+        is_single_file = _is_single_file_path_or_url(pretrained_model_name_or_path)
+        if is_single_file and self.type_hint is None:
+            raise ValueError(
+                f"`type_hint` is required when loading a single file model but is missing for component: {self.name}"
+            )

         if self.type_hint is None:
             try:
                 from diffusers import AutoModel

-                component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
+                component = AutoModel.from_pretrained(pretrained_model_name_or_path, **load_kwargs, **kwargs)
             except Exception as e:
                 raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
             # update type_hint if AutoModel load successfully
             self.type_hint = component.__class__
         else:
+            # determine load method
+            load_method = (
+                getattr(self.type_hint, "from_single_file")
+                if is_single_file
+                else getattr(self.type_hint, "from_pretrained")
+            )
+
             try:
-                component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
+                component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs)
             except Exception as e:
                 raise ValueError(f"Unable to load {self.name} using load method: {e}")

-        self.repo = repo
+        self.pretrained_model_name_or_path = pretrained_model_name_or_path
         for k, v in load_kwargs.items():
             setattr(self, k, v)
         component._diffusers_load_id = self.load_id
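
A hedged sketch of the new single-file branch; the class choice and URL are placeholders:

```py
from diffusers import FluxTransformer2DModel  # placeholder class that supports `from_single_file`

spec = ComponentSpec(
    name="transformer",
    type_hint=FluxTransformer2DModel,  # `type_hint` is required for single-file loads
    pretrained_model_name_or_path="https://huggingface.co/some-org/some-repo/blob/main/model.safetensors",
)
component = spec.load()  # dispatches to FluxTransformer2DModel.from_single_file(...)
```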

@@ -861,7 +861,6 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
         if output_type == "latent":
             image = latents
         else:
-            torch.save({"pred": latents}, "pred_d.pt")
             latents = self._unpack_latents_with_ids(latents, latent_ids)

             latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)

@@ -165,21 +165,16 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
         self,
         prompt: Union[str, List[str]],
         device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        num_images_per_prompt: int = 1,
         do_classifier_free_guidance: bool = True,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         prompt_embeds: Optional[List[torch.FloatTensor]] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 512,
-        lora_scale: Optional[float] = None,
     ):
         prompt = [prompt] if isinstance(prompt, str) else prompt
         prompt_embeds = self._encode_prompt(
             prompt=prompt,
             device=device,
-            dtype=dtype,
-            num_images_per_prompt=num_images_per_prompt,
             prompt_embeds=prompt_embeds,
             max_sequence_length=max_sequence_length,
         )

@@ -193,8 +188,6 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
             negative_prompt_embeds = self._encode_prompt(
                 prompt=negative_prompt,
                 device=device,
-                dtype=dtype,
-                num_images_per_prompt=num_images_per_prompt,
                 prompt_embeds=negative_prompt_embeds,
                 max_sequence_length=max_sequence_length,
             )

@@ -206,12 +199,9 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
         self,
         prompt: Union[str, List[str]],
         device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        num_images_per_prompt: int = 1,
         prompt_embeds: Optional[List[torch.FloatTensor]] = None,
         max_sequence_length: int = 512,
     ) -> List[torch.FloatTensor]:
-        assert num_images_per_prompt == 1
         device = device or self._execution_device

         if prompt_embeds is not None:

@@ -417,8 +407,6 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
                 f"Please adjust the width to a multiple of {vae_scale}."
             )

-        assert self.dtype == torch.bfloat16
-        dtype = self.dtype
         device = self._execution_device

         self._guidance_scale = guidance_scale

@@ -434,10 +422,6 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
         else:
             batch_size = len(prompt_embeds)

-        lora_scale = (
-            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
-        )
-
         # If prompt_embeds is provided and prompt is None, skip encoding
         if prompt_embeds is not None and prompt is None:
             if self.do_classifier_free_guidance and negative_prompt_embeds is None:

@@ -455,11 +439,8 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
             do_classifier_free_guidance=self.do_classifier_free_guidance,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
-            dtype=dtype,
             device=device,
-            num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
-            lora_scale=lora_scale,
         )

         # 4. Prepare latent variables

@@ -475,6 +456,14 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
             generator,
             latents,
         )

+        # Repeat prompt_embeds for num_images_per_prompt
+        if num_images_per_prompt > 1:
+            prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
+            if self.do_classifier_free_guidance and negative_prompt_embeds:
+                negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
+
+        actual_batch_size = batch_size * num_images_per_prompt
         image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)

         # 5. Prepare timesteps

@@ -523,12 +512,12 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
             apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0

             if apply_cfg:
-                latents_typed = latents if latents.dtype == dtype else latents.to(dtype)
+                latents_typed = latents.to(self.transformer.dtype)
                 latent_model_input = latents_typed.repeat(2, 1, 1, 1)
                 prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
                 timestep_model_input = timestep.repeat(2)
             else:
-                latent_model_input = latents if latents.dtype == dtype else latents.to(dtype)
+                latent_model_input = latents.to(self.transformer.dtype)
                 prompt_embeds_model_input = prompt_embeds
                 timestep_model_input = timestep

@@ -543,11 +532,11 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):

             if apply_cfg:
                 # Perform CFG
-                pos_out = model_out_list[:batch_size]
-                neg_out = model_out_list[batch_size:]
+                pos_out = model_out_list[:actual_batch_size]
+                neg_out = model_out_list[actual_batch_size:]

                 noise_pred = []
-                for j in range(batch_size):
+                for j in range(actual_batch_size):
                     pos = pos_out[j].float()
                     neg = neg_out[j].float()

@@ -588,11 +577,11 @@ class ZImagePipeline(DiffusionPipeline, FromSingleFileMixin):
             if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                 progress_bar.update()

-        latents = latents.to(dtype)
         if output_type == "latent":
             image = latents

         else:
+            latents = latents.to(self.vae.dtype)
             latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

             image = self.vae.decode(latents, return_dict=False)[0]

@@ -429,7 +429,22 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         return x_t

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
+    def index_for_timestep(
+        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+    ) -> int:
+        """
+        Find the index for a given timestep in the schedule.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The timestep for which to find the index.
+            schedule_timesteps (`torch.Tensor`, *optional*):
+                The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+        Returns:
+            `int`:
+                The index of the timestep in the schedule.
+        """
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
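
Illustration of the documented lookup (not the scheduler's actual implementation):

```py
import torch

schedule_timesteps = torch.tensor([999, 749, 499, 249])
index = (schedule_timesteps == 499).nonzero().item()
assert index == 2
```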

@@ -452,6 +467,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     def _init_step_index(self, timestep):
+        """
+        Initialize the step_index counter for the scheduler.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The current timestep for which to initialize the step index.
+        """

         if self.begin_index is None:

@@ -401,6 +401,17 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
     def _sigma_to_alpha_sigma_t(self, sigma):
+        """
+        Convert sigma values to alpha_t and sigma_t values.
+
+        Args:
+            sigma (`torch.Tensor`):
+                The sigma value(s) to convert.
+
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`:
+                A tuple containing (alpha_t, sigma_t) values.
+        """
         if self.config.use_flow_sigmas:
             alpha_t = 1 - sigma
             sigma_t = sigma

@@ -808,7 +819,22 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
             raise NotImplementedError("only support log-rho multistep deis now")

     # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
+    def index_for_timestep(
+        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
+    ) -> int:
+        """
+        Find the index for a given timestep in the schedule.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The timestep for which to find the index.
+            schedule_timesteps (`torch.Tensor`, *optional*):
+                The timestep schedule to search in. If `None`, uses `self.timesteps`.
+
+        Returns:
+            `int`:
+                The index of the timestep in the schedule.
+        """
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps

@@ -831,6 +857,10 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     def _init_step_index(self, timestep):
+        """
+        Initialize the step_index counter for the scheduler.
+
+        Args:
+            timestep (`int` or `torch.Tensor`):
+                The current timestep for which to initialize the step index.
+        """

         if self.begin_index is None:

@@ -927,6 +957,21 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
         noise: torch.Tensor,
         timesteps: torch.IntTensor,
     ) -> torch.Tensor:
+        """
+        Add noise to the original samples according to the noise schedule at the specified timesteps.
+
+        Args:
+            original_samples (`torch.Tensor`):
+                The original samples without noise.
+            noise (`torch.Tensor`):
+                The noise to add to the samples.
+            timesteps (`torch.IntTensor`):
+                The timesteps at which to add noise to the samples.
+
+        Returns:
+            `torch.Tensor`:
+                The noisy samples.
+        """
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):

@@ -127,18 +127,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             The starting `beta` value of inference.
         beta_end (`float`, defaults to 0.02):
             The final `beta` value.
-        beta_schedule (`str`, defaults to `"linear"`):
-            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
-            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+        beta_schedule (`"linear"`, `"scaled_linear"`, or `"squaredcos_cap_v2"`, defaults to `"linear"`):
+            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model.
         trained_betas (`np.ndarray`, *optional*):
             Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
         solver_order (`int`, defaults to 2):
             The DPMSolver order which can be `1` or `2` or `3`. It is recommended to use `solver_order=2` for guided
             sampling, and `solver_order=3` for unconditional sampling.
-        prediction_type (`str`, defaults to `epsilon`, *optional*):
-            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
-            `sample` (directly predicts the noisy sample), `v_prediction` (see section 2.4 of [Imagen
-            Video](https://imagen.research.google/video/paper.pdf) paper), or `flow_prediction`.
+        prediction_type (`"epsilon"`, `"sample"`, `"v_prediction"`, or `"flow_prediction"`, defaults to `"epsilon"`):
+            Prediction type of the scheduler function. `epsilon` predicts the noise of the diffusion process, `sample`
+            directly predicts the noisy sample, `v_prediction` predicts the velocity (see section 2.4 of [Imagen
+            Video](https://huggingface.co/papers/2210.02303) paper), and `flow_prediction` predicts the flow.
         thresholding (`bool`, defaults to `False`):
             Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
             as Stable Diffusion.

@@ -147,15 +146,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
         sample_max_value (`float`, defaults to 1.0):
             The threshold value for dynamic thresholding. Valid only when `thresholding=True` and
             `algorithm_type="dpmsolver++"`.
-        algorithm_type (`str`, defaults to `dpmsolver++`):
-            Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The
-            `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927)
-            paper, and the `dpmsolver++` type implements the algorithms in the
-            [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or
-            `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
-        solver_type (`str`, defaults to `midpoint`):
-            Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
-            sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
+        algorithm_type (`"dpmsolver"`, `"dpmsolver++"`, `"sde-dpmsolver"`, or `"sde-dpmsolver++"`, defaults to `"dpmsolver++"`):
+            Algorithm type for the solver. The `dpmsolver` type implements the algorithms in the
+            [DPMSolver](https://huggingface.co/papers/2206.00927) paper, and the `dpmsolver++` type implements the
+            algorithms in the [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use
+            `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion.
+        solver_type (`"midpoint"` or `"heun"`, defaults to `"midpoint"`):
+            Solver type for the second-order solver. The solver type slightly affects the sample quality, especially
+            for a small number of steps. It is recommended to use `midpoint` solvers.
         lower_order_final (`bool`, defaults to `True`):
             Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
             stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
|
||||
@@ -179,16 +177,16 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
|
||||
flow_shift (`float`, *optional*, defaults to 1.0):
|
||||
The shift value for the timestep schedule for flow matching.
|
||||
final_sigmas_type (`str`, defaults to `"zero"`):
|
||||
final_sigmas_type (`"zero"` or `"sigma_min"`, *optional*, defaults to `"zero"`):
|
||||
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
||||
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
||||
sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
|
||||
lambda_min_clipped (`float`, defaults to `-inf`):
|
||||
Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the
|
||||
cosine (`squaredcos_cap_v2`) noise schedule.
|
||||
variance_type (`str`, *optional*):
|
||||
Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output
|
||||
contains the predicted Gaussian variance.
|
||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||
variance_type (`"learned"` or `"learned_range"`, *optional*):
|
||||
Set to `"learned"` or `"learned_range"` for diffusion models that predict variance. If set, the model's
|
||||
output contains the predicted Gaussian variance.
|
||||
timestep_spacing (`"linspace"`, `"leading"`, or `"trailing"`, defaults to `"linspace"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
steps_offset (`int`, defaults to 0):
|
||||
@@ -197,6 +195,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
|
||||
dark samples instead of limiting it to samples with medium brightness. Loosely related to
|
||||
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
|
||||
use_dynamic_shifting (`bool`, defaults to `False`):
|
||||
Whether to use dynamic shifting for the timestep schedule.
|
||||
time_shift_type (`"exponential"`, defaults to `"exponential"`):
|
||||
The type of time shift to apply when using dynamic shifting.
|
||||
"""
|
||||
|
||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
@@ -208,15 +210,15 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
solver_order: int = 2,
|
||||
prediction_type: str = "epsilon",
|
||||
prediction_type: Literal["epsilon", "sample", "v_prediction", "flow_prediction"] = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
sample_max_value: float = 1.0,
|
||||
algorithm_type: str = "dpmsolver++",
|
||||
solver_type: str = "midpoint",
|
||||
algorithm_type: Literal["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"] = "dpmsolver++",
|
||||
solver_type: Literal["midpoint", "heun"] = "midpoint",
|
||||
lower_order_final: bool = True,
|
||||
euler_at_final: bool = False,
|
||||
use_karras_sigmas: Optional[bool] = False,
|
||||
@@ -225,14 +227,14 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
use_lu_lambdas: Optional[bool] = False,
|
||||
use_flow_sigmas: Optional[bool] = False,
|
||||
flow_shift: Optional[float] = 1.0,
|
||||
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
|
||||
final_sigmas_type: Optional[Literal["zero", "sigma_min"]] = "zero",
|
||||
lambda_min_clipped: float = -float("inf"),
|
||||
variance_type: Optional[str] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
variance_type: Optional[Literal["learned", "learned_range"]] = None,
|
||||
timestep_spacing: Literal["linspace", "leading", "trailing"] = "linspace",
|
||||
steps_offset: int = 0,
|
||||
rescale_betas_zero_snr: bool = False,
|
||||
use_dynamic_shifting: bool = False,
|
||||
time_shift_type: str = "exponential",
|
||||
time_shift_type: Literal["exponential"] = "exponential",
|
||||
):
|
||||
if self.config.use_beta_sigmas and not is_scipy_available():
|
||||
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
|
||||
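The switch from bare `str` to `Literal[...]` annotations above narrows what type checkers accept without changing runtime behavior. A minimal construction sketch against this signature (the particular values are illustrative, not taken from the diff):

```py
from diffusers import DPMSolverMultistepScheduler

# Each keyword maps onto one of the Literal-typed arguments above.
scheduler = DPMSolverMultistepScheduler(
    beta_schedule="scaled_linear",
    prediction_type="epsilon",
    algorithm_type="sde-dpmsolver++",  # one of the four allowed algorithm strings
    solver_type="midpoint",
    solver_order=2,  # recommended for guided sampling per the docstring
)
```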
@@ -331,19 +333,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):

    def set_timesteps(
        self,
        num_inference_steps: int = None,
        device: Union[str, torch.device] = None,
        num_inference_steps: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        mu: Optional[float] = None,
        timesteps: Optional[List[int]] = None,
    ):
    ) -> None:
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            mu (`float`, *optional*):
                The mu parameter for dynamic shifting. If provided, requires `use_dynamic_shifting=True` and
                `time_shift_type="exponential"`.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary timesteps schedule. If `None`, timesteps will be generated
                based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
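A hedged sketch of the two entry points the updated signature documents; the step count and explicit timestep values are illustrative:

```py
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler()

# Usual path: let `timestep_spacing` generate the schedule.
scheduler.set_timesteps(num_inference_steps=25, device="cpu")

# Custom path: supply explicit timesteps instead of `num_inference_steps`.
scheduler.set_timesteps(timesteps=[999, 749, 499, 249, 0], device="cpu")
```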
@@ -503,7 +508,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        return sample

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
    def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
        """
        Convert sigma values to corresponding timestep values through interpolation.

@@ -539,7 +544,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        t = t.reshape(sigma.shape)
        return t

    def _sigma_to_alpha_sigma_t(self, sigma):
    def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert sigma values to alpha_t and sigma_t values.

        Args:
            sigma (`torch.Tensor`):
                The sigma value(s) to convert.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`:
                A tuple containing (alpha_t, sigma_t) values.
        """
        if self.config.use_flow_sigmas:
            alpha_t = 1 - sigma
            sigma_t = sigma
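For orientation, the two parameterizations this helper returns look roughly as follows. The flow branch is shown in the hunk above; the non-flow branch is paraphrased from the existing implementation and is not part of this diff:

```py
import torch


def sigma_to_alpha_sigma_t(sigma: torch.Tensor, use_flow_sigmas: bool):
    if use_flow_sigmas:
        # Flow matching: x_t = (1 - sigma) * x_0 + sigma * noise
        alpha_t = 1 - sigma
        sigma_t = sigma
    else:
        # Variance-preserving convention: sigma is the ratio sigma_t / alpha_t
        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
        sigma_t = sigma * alpha_t
    return alpha_t, sigma_t
```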
@@ -588,8 +604,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps) -> torch.Tensor:
        """Constructs the noise schedule of Lu et al. (2022)."""
    def _convert_to_lu(self, in_lambdas: torch.Tensor, num_inference_steps: int) -> torch.Tensor:
        """
        Construct the noise schedule as proposed in [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model
        Sampling in Around 10 Steps](https://huggingface.co/papers/2206.00927) by Lu et al. (2022).

        Args:
            in_lambdas (`torch.Tensor`):
                The input lambda values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            `torch.Tensor`:
                The converted lambda values following the Lu noise schedule.
        """

        lambda_min: float = in_lambdas[-1].item()
        lambda_max: float = in_lambdas[0].item()
@@ -1069,7 +1098,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        )
        return x_t

    def index_for_timestep(self, timestep, schedule_timesteps=None):
    def index_for_timestep(
        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
    ) -> int:
        """
        Find the index for a given timestep in the schedule.

        Args:
            timestep (`int` or `torch.Tensor`):
                The timestep for which to find the index.
            schedule_timesteps (`torch.Tensor`, *optional*):
                The timestep schedule to search in. If `None`, uses `self.timesteps`.

        Returns:
            `int`:
                The index of the timestep in the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

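`index_for_timestep` (together with `_init_step_index` below) is what lets callers resume the multistep solver from an arbitrary point in the schedule. A small illustrative sketch:

```py
from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler()
scheduler.set_timesteps(num_inference_steps=10)

# Look up where a concrete timestep sits in the active schedule.
t = scheduler.timesteps[3]
print(scheduler.index_for_timestep(t))  # expected: 3 for a uniquely matching timestep
```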
@@ -1088,9 +1132,13 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):

        return step_index

    def _init_step_index(self, timestep):
    def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
        """
        Initialize the step_index counter for the scheduler.

        Args:
            timestep (`int` or `torch.Tensor`):
                The current timestep for which to initialize the step index.
        """

        if self.begin_index is None:
@@ -1105,7 +1153,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        generator=None,
        generator: Optional[torch.Generator] = None,
        variance_noise: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[SchedulerOutput, Tuple]:
@@ -1115,22 +1163,22 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):

        Args:
            model_output (`torch.Tensor`):
                The direct output from learned diffusion model.
            timestep (`int`):
                The direct output from the learned diffusion model.
            timestep (`int` or `torch.Tensor`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            variance_noise (`torch.Tensor`):
            variance_noise (`torch.Tensor`, *optional*):
                Alternative to generating noise with `generator` by directly providing the noise for the variance
                itself. Useful for methods such as [`LEdits++`].
            return_dict (`bool`):
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

        Returns:
            [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        """
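For context, `step` is consumed once per scheduler timestep inside a denoising loop. A self-contained sketch, with random tensors standing in for a real model and latents:

```py
import torch

from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler()
scheduler.set_timesteps(num_inference_steps=25, device="cpu")

sample = torch.randn(1, 4, 64, 64)  # stand-in for an initial noisy latent
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)  # stand-in for a denoiser forward pass
    sample = scheduler.step(model_output, t, sample, return_dict=True).prev_sample
```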
@@ -1210,6 +1258,21 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Add noise to the original samples according to the noise schedule at the specified timesteps.

        Args:
            original_samples (`torch.Tensor`):
                The original samples without noise.
            noise (`torch.Tensor`):
                The noise to add to the samples.
            timesteps (`torch.IntTensor`):
                The timesteps at which to add noise to the samples.

        Returns:
            `torch.Tensor`:
                The noisy samples.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):

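`add_noise` applies the forward process `x_t = alpha_t * x_0 + sigma_t * noise`, looking up `(alpha_t, sigma_t)` at each requested timestep. A minimal sketch (shapes and step count are illustrative):

```py
import torch

from diffusers import DPMSolverMultistepScheduler

scheduler = DPMSolverMultistepScheduler()
scheduler.set_timesteps(num_inference_steps=25)

clean = torch.randn(1, 4, 64, 64)
noise = torch.randn_like(clean)
t = scheduler.timesteps[:1]  # noise the batch to the first (largest) timestep
noisy = scheduler.add_noise(clean, noise, t)
```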
@@ -413,6 +413,17 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
    def _sigma_to_alpha_sigma_t(self, sigma):
        """
        Convert sigma values to alpha_t and sigma_t values.

        Args:
            sigma (`torch.Tensor`):
                The sigma value(s) to convert.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`:
                A tuple containing (alpha_t, sigma_t) values.
        """
        if self.config.use_flow_sigmas:
            alpha_t = 1 - sigma
            sigma_t = sigma

@@ -491,6 +491,17 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
    def _sigma_to_alpha_sigma_t(self, sigma):
        """
        Convert sigma values to alpha_t and sigma_t values.

        Args:
            sigma (`torch.Tensor`):
                The sigma value(s) to convert.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`:
                A tuple containing (alpha_t, sigma_t) values.
        """
        if self.config.use_flow_sigmas:
            alpha_t = 1 - sigma
            sigma_t = sigma
@@ -1079,7 +1090,22 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
        raise ValueError(f"Order must be 1, 2, 3, got {order}")

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
    def index_for_timestep(
        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
    ) -> int:
        """
        Find the index for a given timestep in the schedule.

        Args:
            timestep (`int` or `torch.Tensor`):
                The timestep for which to find the index.
            schedule_timesteps (`torch.Tensor`, *optional*):
                The timestep schedule to search in. If `None`, uses `self.timesteps`.

        Returns:
            `int`:
                The index of the timestep in the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

@@ -1102,6 +1128,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.

        Args:
            timestep (`int` or `torch.Tensor`):
                The current timestep for which to initialize the step index.
        """

        if self.begin_index is None:
@@ -1204,6 +1234,21 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Add noise to the original samples according to the noise schedule at the specified timesteps.

        Args:
            original_samples (`torch.Tensor`):
                The original samples without noise.
            noise (`torch.Tensor`):
                The noise to add to the samples.
            timesteps (`torch.IntTensor`):
                The timesteps at which to add noise to the samples.

        Returns:
            `torch.Tensor`:
                The noisy samples.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):

@@ -578,7 +578,22 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
        return x_t

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
    def index_for_timestep(
        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
    ) -> int:
        """
        Find the index for a given timestep in the schedule.

        Args:
            timestep (`int` or `torch.Tensor`):
                The timestep for which to find the index.
            schedule_timesteps (`torch.Tensor`, *optional*):
                The timestep schedule to search in. If `None`, uses `self.timesteps`.

        Returns:
            `int`:
                The index of the timestep in the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

@@ -601,6 +616,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.

        Args:
            timestep (`int` or `torch.Tensor`):
                The current timestep for which to initialize the step index.
        """

        if self.begin_index is None:

@@ -423,6 +423,17 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
    def _sigma_to_alpha_sigma_t(self, sigma):
        """
        Convert sigma values to alpha_t and sigma_t values.

        Args:
            sigma (`torch.Tensor`):
                The sigma value(s) to convert.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`:
                A tuple containing (alpha_t, sigma_t) values.
        """
        if self.config.use_flow_sigmas:
            alpha_t = 1 - sigma
            sigma_t = sigma
@@ -1103,7 +1114,22 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
        return x_t

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
    def index_for_timestep(
        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
    ) -> int:
        """
        Find the index for a given timestep in the schedule.

        Args:
            timestep (`int` or `torch.Tensor`):
                The timestep for which to find the index.
            schedule_timesteps (`torch.Tensor`, *optional*):
                The timestep schedule to search in. If `None`, uses `self.timesteps`.

        Returns:
            `int`:
                The index of the timestep in the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

@@ -1126,6 +1152,10 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.

        Args:
            timestep (`int` or `torch.Tensor`):
                The current timestep for which to initialize the step index.
        """

        if self.begin_index is None:

@@ -513,6 +513,17 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
    def _sigma_to_alpha_sigma_t(self, sigma):
        """
        Convert sigma values to alpha_t and sigma_t values.

        Args:
            sigma (`torch.Tensor`):
                The sigma value(s) to convert.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`:
                A tuple containing (alpha_t, sigma_t) values.
        """
        if self.config.use_flow_sigmas:
            alpha_t = 1 - sigma
            sigma_t = sigma
@@ -984,7 +995,22 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
        return x_t

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
    def index_for_timestep(
        self, timestep: Union[int, torch.Tensor], schedule_timesteps: Optional[torch.Tensor] = None
    ) -> int:
        """
        Find the index for a given timestep in the schedule.

        Args:
            timestep (`int` or `torch.Tensor`):
                The timestep for which to find the index.
            schedule_timesteps (`torch.Tensor`, *optional*):
                The timestep schedule to search in. If `None`, uses `self.timesteps`.

        Returns:
            `int`:
                The index of the timestep in the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

@@ -1007,6 +1033,10 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
    def _init_step_index(self, timestep):
        """
        Initialize the step_index counter for the scheduler.

        Args:
            timestep (`int` or `torch.Tensor`):
                The current timestep for which to initialize the step index.
        """

        if self.begin_index is None:
@@ -1119,6 +1149,21 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
        noise: torch.Tensor,
        timesteps: torch.IntTensor,
    ) -> torch.Tensor:
        """
        Add noise to the original samples according to the noise schedule at the specified timesteps.

        Args:
            original_samples (`torch.Tensor`):
                The original samples without noise.
            noise (`torch.Tensor`):
                The noise to add to the samples.
            timesteps (`torch.IntTensor`):
                The timesteps at which to add noise to the samples.

        Returns:
            `torch.Tensor`:
                The noisy samples.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):

@@ -36,7 +36,7 @@ from ..test_modular_pipelines_common import ModularPipelineTesterMixin
class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
    pipeline_class = FluxModularPipeline
    pipeline_blocks_class = FluxAutoBlocks
    repo = "hf-internal-testing/tiny-flux-modular"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"

    params = frozenset(["prompt", "height", "width", "guidance_scale"])
    batch_params = frozenset(["prompt"])
@@ -62,7 +62,7 @@ class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
    pipeline_class = FluxModularPipeline
    pipeline_blocks_class = FluxAutoBlocks
    repo = "hf-internal-testing/tiny-flux-modular"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"

    params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
    batch_params = frozenset(["prompt", "image"])
@@ -128,7 +128,7 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
    pipeline_class = FluxKontextModularPipeline
    pipeline_blocks_class = FluxKontextAutoBlocks
    repo = "hf-internal-testing/tiny-flux-kontext-pipe"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"

    params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
    batch_params = frozenset(["prompt", "image"])

0	tests/modular_pipelines/flux2/__init__.py	Normal file
139	tests/modular_pipelines/flux2/test_modular_pipeline_flux2.py	Normal file
@@ -0,0 +1,139 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import tempfile
import unittest

import numpy as np
import PIL
import torch

from diffusers.modular_pipelines import (
    Flux2AutoBlocks,
    Flux2ModularPipeline,
    ModularPipeline,
)
from diffusers.modular_pipelines.flux2 import (
    Flux2AutoTextEncoderStep,
    Flux2RemoteTextEncoderStep,
    Flux2TextEncoderStep,
)

from ...testing_utils import floats_tensor, torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin


class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
    pipeline_class = Flux2ModularPipeline
    pipeline_blocks_class = Flux2AutoBlocks
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"

    params = frozenset(["prompt", "height", "width", "guidance_scale"])
    batch_params = frozenset(["prompt"])

    def get_dummy_inputs(self, seed=0):
        generator = self.get_generator(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
            "max_sequence_length": 8,  # bit of a hack to workaround vocab size mismatch
            "text_encoder_out_layers": (1,),
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 4.0,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }
        return inputs

    def test_float16_inference(self):
        super().test_float16_inference(9e-2)


class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
    pipeline_class = Flux2ModularPipeline
    pipeline_blocks_class = Flux2AutoBlocks
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"

    params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
    batch_params = frozenset(["prompt", "image"])

    def get_dummy_inputs(self, seed=0):
        generator = self.get_generator(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            # TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
            "max_sequence_length": 8,  # bit of a hack to workaround vocab size mismatch
            "text_encoder_out_layers": (1,),
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 4.0,
            "height": 32,
            "width": 32,
            "output_type": "pt",
        }
        image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
        image = image.cpu().permute(0, 2, 3, 1)[0]
        init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
        inputs["image"] = init_image

        return inputs

    def test_save_from_pretrained(self):
        pipes = []
        base_pipe = self.get_pipeline().to(torch_device)
        pipes.append(base_pipe)

        with tempfile.TemporaryDirectory() as tmpdirname:
            base_pipe.save_pretrained(tmpdirname)

            pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
            pipe.load_components(torch_dtype=torch.float32)
            pipe.to(torch_device)

        pipes.append(pipe)

        image_slices = []
        for pipe in pipes:
            inputs = self.get_dummy_inputs()
            image = pipe(**inputs, output="images")

            image_slices.append(image[0, -3:, -3:, -1].flatten())

        assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3

    def test_float16_inference(self):
        super().test_float16_inference(9e-2)


class TestFlux2AutoTextEncoderStep(unittest.TestCase):
    def test_auto_text_encoder_block_classes(self):
        auto_step = Flux2AutoTextEncoderStep()

        assert len(auto_step.block_classes) == 2
        assert Flux2RemoteTextEncoderStep in auto_step.block_classes
        assert Flux2TextEncoderStep in auto_step.block_classes

    def test_auto_text_encoder_trigger_inputs(self):
        auto_step = Flux2AutoTextEncoderStep()

        assert auto_step.block_trigger_inputs == ["remote_text_encoder", None]

    def test_auto_text_encoder_block_names(self):
        auto_step = Flux2AutoTextEncoderStep()

        assert auto_step.block_names == ["remote", "local"]
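The three assertions above pin down how the auto step dispatches: trigger inputs map positionally onto block names, with the `None` entry acting as the fallback. A small sketch of reading that mapping back out, assuming the classes are importable as in the test file:

```py
from diffusers.modular_pipelines.flux2 import Flux2AutoTextEncoderStep

auto_step = Flux2AutoTextEncoderStep()
# Prints: remote_text_encoder -> remote, then None -> local (the fallback block).
for trigger, name in zip(auto_step.block_trigger_inputs, auto_step.block_names):
    print(trigger, "->", name)
```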
@@ -32,7 +32,7 @@ from ..test_modular_pipelines_common import ModularGuiderTesterMixin, ModularPip
class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
    pipeline_class = QwenImageModularPipeline
    pipeline_blocks_class = QwenImageAutoBlocks
    repo = "hf-internal-testing/tiny-qwenimage-modular"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-modular"

    params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
    batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
@@ -58,7 +58,7 @@ class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuider
class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
    pipeline_class = QwenImageEditModularPipeline
    pipeline_blocks_class = QwenImageEditAutoBlocks
    repo = "hf-internal-testing/tiny-qwenimage-edit-modular"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-modular"

    params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
    batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
@@ -84,7 +84,7 @@ class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, ModularGu
class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, ModularGuiderTesterMixin):
    pipeline_class = QwenImageEditPlusModularPipeline
    pipeline_blocks_class = QwenImageEditPlusAutoBlocks
    repo = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"

    # No `mask_image` yet.
    params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])

@@ -105,7 +105,7 @@ class SDXLModularIPAdapterTesterMixin:

        blocks = self.pipeline_blocks_class()
        _ = blocks.sub_blocks.pop("ip_adapter")
        pipe = blocks.init_pipeline(self.repo)
        pipe = blocks.init_pipeline(self.pretrained_model_name_or_path)
        pipe.load_components(torch_dtype=torch.float32)
        pipe = pipe.to(torch_device)

@@ -278,7 +278,7 @@ class TestSDXLModularPipelineFast(

    pipeline_class = StableDiffusionXLModularPipeline
    pipeline_blocks_class = StableDiffusionXLAutoBlocks
    repo = "hf-internal-testing/tiny-sdxl-modular"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
    params = frozenset(
        [
            "prompt",
@@ -325,7 +325,7 @@ class TestSDXLImg2ImgModularPipelineFast(

    pipeline_class = StableDiffusionXLModularPipeline
    pipeline_blocks_class = StableDiffusionXLAutoBlocks
    repo = "hf-internal-testing/tiny-sdxl-modular"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
    params = frozenset(
        [
            "prompt",
@@ -378,7 +378,7 @@ class SDXLInpaintingModularPipelineFastTests(

    pipeline_class = StableDiffusionXLModularPipeline
    pipeline_blocks_class = StableDiffusionXLAutoBlocks
    repo = "hf-internal-testing/tiny-sdxl-modular"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-sdxl-modular"
    params = frozenset(
        [
            "prompt",

@@ -43,9 +43,9 @@ class ModularPipelineTesterMixin:
        )

    @property
    def repo(self) -> str:
    def pretrained_model_name_or_path(self) -> str:
        raise NotImplementedError(
            "You need to set the attribute `repo` in the child test class. See existing pipeline tests for reference."
            "You need to set the attribute `pretrained_model_name_or_path` in the child test class. See existing pipeline tests for reference."
        )

    @property
@@ -103,7 +103,9 @@ class ModularPipelineTesterMixin:
        backend_empty_cache(torch_device)

    def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
        pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
        pipeline = self.pipeline_blocks_class().init_pipeline(
            self.pretrained_model_name_or_path, components_manager=components_manager
        )
        pipeline.load_components(torch_dtype=torch_dtype)
        pipeline.set_progress_bar_config(disable=None)
        return pipeline

306	tests/pipelines/z_image/test_z_image.py	Normal file
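With the property renamed, a child test class now declares `pretrained_model_name_or_path` instead of `repo`. A hypothetical subclass sketch, reusing classes and the tiny repo id that appear elsewhere in this diff:

```py
from diffusers.modular_pipelines import FluxAutoBlocks, FluxModularPipeline

from ..test_modular_pipelines_common import ModularPipelineTesterMixin


class TestMyFluxVariantFast(ModularPipelineTesterMixin):
    pipeline_class = FluxModularPipeline
    pipeline_blocks_class = FluxAutoBlocks
    # Consumed by get_pipeline() via init_pipeline(self.pretrained_model_name_or_path).
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-modular"

    params = frozenset(["prompt"])
    batch_params = frozenset(["prompt"])
```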
@@ -0,0 +1,306 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import os
import unittest

import numpy as np
import torch
from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model

from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    ZImagePipeline,
    ZImageTransformer2DModel,
)

from ...testing_utils import torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np


# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
# Cannot use enable_full_determinism() which sets it to True
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if hasattr(torch.backends, "cuda"):
    torch.backends.cuda.matmul.allow_tf32 = False

# Note: Some tests (test_float16_inference, test_save_load_float16) may fail in full suite
# due to RopeEmbedder cache state pollution between tests. They pass when run individually.
# This is a known test isolation issue, not a functional bug.


class ZImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = ZImagePipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    supports_dduf = False
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def setUp(self):
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = ZImageTransformer2DModel(
            all_patch_size=(2,),
            all_f_patch_size=(1,),
            in_channels=16,
            dim=32,
            n_layers=2,
            n_refiner_layers=1,
            n_heads=2,
            n_kv_heads=2,
            norm_eps=1e-5,
            qk_norm=True,
            cap_feat_dim=16,
            rope_theta=256.0,
            t_scale=1000.0,
            axes_dims=[8, 4, 4],
            axes_lens=[256, 32, 32],
        )

        torch.manual_seed(0)
        vae = AutoencoderKL(
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            block_out_channels=[32, 64],
            layers_per_block=1,
            latent_channels=16,
            norm_num_groups=32,
            sample_size=32,
            scaling_factor=0.3611,
            shift_factor=0.1159,
        )

        torch.manual_seed(0)
        scheduler = FlowMatchEulerDiscreteScheduler()

        torch.manual_seed(0)
        config = Qwen3Config(
            hidden_size=16,
            intermediate_size=16,
            num_hidden_layers=2,
            num_attention_heads=2,
            num_key_value_heads=2,
            vocab_size=151936,
            max_position_embeddings=512,
        )
        text_encoder = Qwen3Model(config)
        tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 3.0,
            "cfg_normalization": False,
            "cfg_truncation": 1.0,
            "height": 32,
            "width": 32,
            "max_sequence_length": 16,
            "output_type": "pt",
        }

        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        generated_image = image[0]
        self.assertEqual(generated_image.shape, (3, 32, 32))

        # fmt: off
        expected_slice = torch.tensor([0.4521, 0.4512, 0.4693, 0.5115, 0.5250, 0.5271, 0.4776, 0.4688, 0.2765, 0.2164, 0.5656, 0.6909, 0.3831, 0.5431, 0.5493, 0.4732])
        # fmt: on

        generated_slice = generated_image.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=5e-2))

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)

    def test_num_images_per_prompt(self):
        import inspect

        sig = inspect.signature(self.pipeline_class.__call__)

        if "num_images_per_prompt" not in sig.parameters:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        batch_sizes = [1, 2]
        num_images_per_prompts = [1, 2]

        for batch_size in batch_sizes:
            for num_images_per_prompt in num_images_per_prompts:
                inputs = self.get_dummy_inputs(torch_device)

                for key in inputs.keys():
                    if key in self.batch_params:
                        inputs[key] = batch_size * [inputs[key]]

                images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0]

                assert images.shape[0] == batch_size * num_images_per_prompt

        del pipe
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

    def test_vae_tiling(self, expected_diff_max: float = 0.2):
        generator_device = "cpu"
        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        # Without tiling
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_without_tiling = pipe(**inputs)[0]

        # With tiling (standard AutoencoderKL doesn't accept parameters)
        pipe.vae.enable_tiling()
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_with_tiling = pipe(**inputs)[0]

        self.assertLess(
            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
            expected_diff_max,
            "VAE tiling should not affect the inference results",
        )

    def test_pipeline_with_accelerator_device_map(self, expected_max_difference=5e-4):
        # Z-Image RoPE embeddings (complex64) have slightly higher numerical tolerance
        super().test_pipeline_with_accelerator_device_map(expected_max_difference=expected_max_difference)

    def test_group_offloading_inference(self):
        # Block-level offloading conflicts with RoPE cache. Pipeline-level offloading (tested separately) works fine.
        self.skipTest("Using test_pipeline_level_group_offloading_inference instead")

    def test_save_load_float16(self, expected_max_diff=1e-2):
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)
        super().test_save_load_float16(expected_max_diff=expected_max_diff)