mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-17 01:44:43 +08:00
Compare commits
11 Commits
device-map
...
flux2-modu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
21ecee5655 | ||
|
|
04de34b82e | ||
|
|
2e8c97b734 | ||
|
|
3806a9add3 | ||
|
|
75876748e5 | ||
|
|
771512a46d | ||
|
|
b0f50c64e1 | ||
|
|
921b959b9a | ||
|
|
9391a5465d | ||
|
|
d780d1a42a | ||
|
|
9264459f88 |
@@ -399,6 +399,8 @@ except OptionalDependencyNotAvailable:
|
|||||||
else:
|
else:
|
||||||
_import_structure["modular_pipelines"].extend(
|
_import_structure["modular_pipelines"].extend(
|
||||||
[
|
[
|
||||||
|
"Flux2AutoBlocks",
|
||||||
|
"Flux2ModularPipeline",
|
||||||
"FluxAutoBlocks",
|
"FluxAutoBlocks",
|
||||||
"FluxKontextAutoBlocks",
|
"FluxKontextAutoBlocks",
|
||||||
"FluxKontextModularPipeline",
|
"FluxKontextModularPipeline",
|
||||||
@@ -1091,6 +1093,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||||
else:
|
else:
|
||||||
from .modular_pipelines import (
|
from .modular_pipelines import (
|
||||||
|
Flux2AutoBlocks,
|
||||||
|
Flux2ModularPipeline,
|
||||||
FluxAutoBlocks,
|
FluxAutoBlocks,
|
||||||
FluxKontextAutoBlocks,
|
FluxKontextAutoBlocks,
|
||||||
FluxKontextModularPipeline,
|
FluxKontextModularPipeline,
|
||||||
|
|||||||
@@ -52,6 +52,10 @@ else:
|
|||||||
"FluxKontextAutoBlocks",
|
"FluxKontextAutoBlocks",
|
||||||
"FluxKontextModularPipeline",
|
"FluxKontextModularPipeline",
|
||||||
]
|
]
|
||||||
|
_import_structure["flux2"] = [
|
||||||
|
"Flux2AutoBlocks",
|
||||||
|
"Flux2ModularPipeline",
|
||||||
|
]
|
||||||
_import_structure["qwenimage"] = [
|
_import_structure["qwenimage"] = [
|
||||||
"QwenImageAutoBlocks",
|
"QwenImageAutoBlocks",
|
||||||
"QwenImageModularPipeline",
|
"QwenImageModularPipeline",
|
||||||
@@ -71,6 +75,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|||||||
else:
|
else:
|
||||||
from .components_manager import ComponentsManager
|
from .components_manager import ComponentsManager
|
||||||
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
|
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
|
||||||
|
from .flux2 import Flux2AutoBlocks, Flux2ModularPipeline
|
||||||
from .modular_pipeline import (
|
from .modular_pipeline import (
|
||||||
AutoPipelineBlocks,
|
AutoPipelineBlocks,
|
||||||
BlockState,
|
BlockState,
|
||||||
|
|||||||
111
src/diffusers/modular_pipelines/flux2/__init__.py
Normal file
111
src/diffusers/modular_pipelines/flux2/__init__.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from ...utils import (
|
||||||
|
DIFFUSERS_SLOW_IMPORT,
|
||||||
|
OptionalDependencyNotAvailable,
|
||||||
|
_LazyModule,
|
||||||
|
get_objects_from_module,
|
||||||
|
is_torch_available,
|
||||||
|
is_transformers_available,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_dummy_objects = {}
|
||||||
|
_import_structure = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not (is_transformers_available() and is_torch_available()):
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||||
|
|
||||||
|
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||||
|
else:
|
||||||
|
_import_structure["encoders"] = [
|
||||||
|
"Flux2TextEncoderStep",
|
||||||
|
"Flux2RemoteTextEncoderStep",
|
||||||
|
"Flux2VaeEncoderStep",
|
||||||
|
]
|
||||||
|
_import_structure["before_denoise"] = [
|
||||||
|
"Flux2SetTimestepsStep",
|
||||||
|
"Flux2PrepareLatentsStep",
|
||||||
|
"Flux2RoPEInputsStep",
|
||||||
|
"Flux2PrepareImageLatentsStep",
|
||||||
|
]
|
||||||
|
_import_structure["denoise"] = [
|
||||||
|
"Flux2LoopDenoiser",
|
||||||
|
"Flux2LoopAfterDenoiser",
|
||||||
|
"Flux2DenoiseLoopWrapper",
|
||||||
|
"Flux2DenoiseStep",
|
||||||
|
]
|
||||||
|
_import_structure["decoders"] = ["Flux2DecodeStep"]
|
||||||
|
_import_structure["inputs"] = [
|
||||||
|
"Flux2ProcessImagesInputStep",
|
||||||
|
"Flux2TextInputStep",
|
||||||
|
]
|
||||||
|
_import_structure["modular_blocks"] = [
|
||||||
|
"ALL_BLOCKS",
|
||||||
|
"AUTO_BLOCKS",
|
||||||
|
"REMOTE_AUTO_BLOCKS",
|
||||||
|
"TEXT2IMAGE_BLOCKS",
|
||||||
|
"IMAGE_CONDITIONED_BLOCKS",
|
||||||
|
"Flux2AutoBlocks",
|
||||||
|
"Flux2AutoVaeEncoderStep",
|
||||||
|
"Flux2BeforeDenoiseStep",
|
||||||
|
"Flux2VaeEncoderSequentialStep",
|
||||||
|
]
|
||||||
|
_import_structure["modular_pipeline"] = ["Flux2ModularPipeline"]
|
||||||
|
|
||||||
|
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||||
|
try:
|
||||||
|
if not (is_transformers_available() and is_torch_available()):
|
||||||
|
raise OptionalDependencyNotAvailable()
|
||||||
|
except OptionalDependencyNotAvailable:
|
||||||
|
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||||
|
else:
|
||||||
|
from .before_denoise import (
|
||||||
|
Flux2PrepareImageLatentsStep,
|
||||||
|
Flux2PrepareLatentsStep,
|
||||||
|
Flux2RoPEInputsStep,
|
||||||
|
Flux2SetTimestepsStep,
|
||||||
|
)
|
||||||
|
from .decoders import Flux2DecodeStep
|
||||||
|
from .denoise import (
|
||||||
|
Flux2DenoiseLoopWrapper,
|
||||||
|
Flux2DenoiseStep,
|
||||||
|
Flux2LoopAfterDenoiser,
|
||||||
|
Flux2LoopDenoiser,
|
||||||
|
)
|
||||||
|
from .encoders import (
|
||||||
|
Flux2RemoteTextEncoderStep,
|
||||||
|
Flux2TextEncoderStep,
|
||||||
|
Flux2VaeEncoderStep,
|
||||||
|
)
|
||||||
|
from .inputs import (
|
||||||
|
Flux2ProcessImagesInputStep,
|
||||||
|
Flux2TextInputStep,
|
||||||
|
)
|
||||||
|
from .modular_blocks import (
|
||||||
|
ALL_BLOCKS,
|
||||||
|
AUTO_BLOCKS,
|
||||||
|
IMAGE_CONDITIONED_BLOCKS,
|
||||||
|
REMOTE_AUTO_BLOCKS,
|
||||||
|
TEXT2IMAGE_BLOCKS,
|
||||||
|
Flux2AutoBlocks,
|
||||||
|
Flux2AutoVaeEncoderStep,
|
||||||
|
Flux2BeforeDenoiseStep,
|
||||||
|
Flux2VaeEncoderSequentialStep,
|
||||||
|
)
|
||||||
|
from .modular_pipeline import Flux2ModularPipeline
|
||||||
|
else:
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.modules[__name__] = _LazyModule(
|
||||||
|
__name__,
|
||||||
|
globals()["__file__"],
|
||||||
|
_import_structure,
|
||||||
|
module_spec=__spec__,
|
||||||
|
)
|
||||||
|
|
||||||
|
for name, value in _dummy_objects.items():
|
||||||
|
setattr(sys.modules[__name__], name, value)
|
||||||
508
src/diffusers/modular_pipelines/flux2/before_denoise.py
Normal file
508
src/diffusers/modular_pipelines/flux2/before_denoise.py
Normal file
@@ -0,0 +1,508 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...models import Flux2Transformer2DModel
|
||||||
|
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||||
|
from ...utils import logging
|
||||||
|
from ...utils.torch_utils import randn_tensor
|
||||||
|
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||||
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
|
from .modular_pipeline import Flux2ModularPipeline
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
|
||||||
|
"""Compute empirical mu for Flux2 timestep scheduling."""
|
||||||
|
a1, b1 = 8.73809524e-05, 1.89833333
|
||||||
|
a2, b2 = 0.00016927, 0.45666666
|
||||||
|
|
||||||
|
if image_seq_len > 4300:
|
||||||
|
mu = a2 * image_seq_len + b2
|
||||||
|
return float(mu)
|
||||||
|
|
||||||
|
m_200 = a2 * image_seq_len + b2
|
||||||
|
m_10 = a1 * image_seq_len + b1
|
||||||
|
|
||||||
|
a = (m_200 - m_10) / 190.0
|
||||||
|
b = m_200 - 200.0 * a
|
||||||
|
mu = a * num_steps + b
|
||||||
|
|
||||||
|
return float(mu)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||||
|
def retrieve_timesteps(
|
||||||
|
scheduler,
|
||||||
|
num_inference_steps: Optional[int] = None,
|
||||||
|
device: Optional[Union[str, torch.device]] = None,
|
||||||
|
timesteps: Optional[List[int]] = None,
|
||||||
|
sigmas: Optional[List[float]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||||
|
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scheduler (`SchedulerMixin`):
|
||||||
|
The scheduler to get timesteps from.
|
||||||
|
num_inference_steps (`int`):
|
||||||
|
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||||
|
must be `None`.
|
||||||
|
device (`str` or `torch.device`, *optional*):
|
||||||
|
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||||
|
timesteps (`List[int]`, *optional*):
|
||||||
|
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||||
|
`num_inference_steps` and `sigmas` must be `None`.
|
||||||
|
sigmas (`List[float]`, *optional*):
|
||||||
|
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||||
|
`num_inference_steps` and `timesteps` must be `None`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||||
|
second element is the number of inference steps.
|
||||||
|
"""
|
||||||
|
if timesteps is not None and sigmas is not None:
|
||||||
|
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||||
|
if timesteps is not None:
|
||||||
|
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||||
|
if not accepts_timesteps:
|
||||||
|
raise ValueError(
|
||||||
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||||
|
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||||
|
)
|
||||||
|
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
num_inference_steps = len(timesteps)
|
||||||
|
elif sigmas is not None:
|
||||||
|
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||||
|
if not accept_sigmas:
|
||||||
|
raise ValueError(
|
||||||
|
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||||
|
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||||
|
)
|
||||||
|
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
num_inference_steps = len(timesteps)
|
||||||
|
else:
|
||||||
|
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||||
|
timesteps = scheduler.timesteps
|
||||||
|
return timesteps, num_inference_steps
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2SetTimestepsStep(ModularPipelineBlocks):
|
||||||
|
model_name = "flux2"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||||
|
ComponentSpec("transformer", Flux2Transformer2DModel),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Step that sets the scheduler's timesteps for Flux2 inference using empirical mu calculation"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("num_inference_steps", default=50),
|
||||||
|
InputParam("timesteps"),
|
||||||
|
InputParam("sigmas"),
|
||||||
|
InputParam("guidance_scale", default=4.0),
|
||||||
|
InputParam("latents", type_hint=torch.Tensor),
|
||||||
|
InputParam("num_images_per_prompt", default=1),
|
||||||
|
InputParam("height", type_hint=int),
|
||||||
|
InputParam("width", type_hint=int),
|
||||||
|
InputParam(
|
||||||
|
"batch_size",
|
||||||
|
required=True,
|
||||||
|
type_hint=int,
|
||||||
|
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
|
||||||
|
OutputParam(
|
||||||
|
"num_inference_steps",
|
||||||
|
type_hint=int,
|
||||||
|
description="The number of denoising steps to perform at inference time",
|
||||||
|
),
|
||||||
|
OutputParam("guidance", type_hint=torch.Tensor, description="Guidance scale tensor"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
block_state.device = components._execution_device
|
||||||
|
|
||||||
|
scheduler = components.scheduler
|
||||||
|
|
||||||
|
height = block_state.height or components.default_height
|
||||||
|
width = block_state.width or components.default_width
|
||||||
|
vae_scale_factor = components.vae_scale_factor
|
||||||
|
|
||||||
|
latent_height = 2 * (int(height) // (vae_scale_factor * 2))
|
||||||
|
latent_width = 2 * (int(width) // (vae_scale_factor * 2))
|
||||||
|
image_seq_len = (latent_height // 2) * (latent_width // 2)
|
||||||
|
|
||||||
|
num_inference_steps = block_state.num_inference_steps
|
||||||
|
sigmas = block_state.sigmas
|
||||||
|
timesteps = block_state.timesteps
|
||||||
|
|
||||||
|
if timesteps is None and sigmas is None:
|
||||||
|
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||||
|
if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
|
||||||
|
sigmas = None
|
||||||
|
|
||||||
|
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
|
||||||
|
|
||||||
|
timesteps, num_inference_steps = retrieve_timesteps(
|
||||||
|
scheduler,
|
||||||
|
num_inference_steps,
|
||||||
|
block_state.device,
|
||||||
|
timesteps=timesteps,
|
||||||
|
sigmas=sigmas,
|
||||||
|
mu=mu,
|
||||||
|
)
|
||||||
|
block_state.timesteps = timesteps
|
||||||
|
block_state.num_inference_steps = num_inference_steps
|
||||||
|
|
||||||
|
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||||
|
guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
|
||||||
|
guidance = guidance.expand(batch_size)
|
||||||
|
block_state.guidance = guidance
|
||||||
|
|
||||||
|
components.scheduler.set_begin_index(0)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2PrepareLatentsStep(ModularPipelineBlocks):
|
||||||
|
model_name = "flux2"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Prepare latents step that prepares the initial noise latents for Flux2 text-to-image generation"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("height", type_hint=int),
|
||||||
|
InputParam("width", type_hint=int),
|
||||||
|
InputParam("latents", type_hint=Optional[torch.Tensor]),
|
||||||
|
InputParam("num_images_per_prompt", type_hint=int, default=1),
|
||||||
|
InputParam("generator"),
|
||||||
|
InputParam(
|
||||||
|
"batch_size",
|
||||||
|
required=True,
|
||||||
|
type_hint=int,
|
||||||
|
description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.",
|
||||||
|
),
|
||||||
|
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
|
||||||
|
),
|
||||||
|
OutputParam("latent_ids", type_hint=torch.Tensor, description="Position IDs for the latents (for RoPE)"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_inputs(components, block_state):
|
||||||
|
vae_scale_factor = components.vae_scale_factor
|
||||||
|
if (block_state.height is not None and block_state.height % (vae_scale_factor * 2) != 0) or (
|
||||||
|
block_state.width is not None and block_state.width % (vae_scale_factor * 2) != 0
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"`height` and `width` have to be divisible by {vae_scale_factor * 2} but are {block_state.height} and {block_state.width}."
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _prepare_latent_ids(latents: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Generates 4D position coordinates (T, H, W, L) for latent tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
latents: Latent tensor of shape (B, C, H, W)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Position IDs tensor of shape (B, H*W, 4)
|
||||||
|
"""
|
||||||
|
batch_size, _, height, width = latents.shape
|
||||||
|
|
||||||
|
t = torch.arange(1)
|
||||||
|
h = torch.arange(height)
|
||||||
|
w = torch.arange(width)
|
||||||
|
l = torch.arange(1)
|
||||||
|
|
||||||
|
latent_ids = torch.cartesian_prod(t, h, w, l)
|
||||||
|
latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
|
||||||
|
|
||||||
|
return latent_ids
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _pack_latents(latents):
|
||||||
|
"""Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)"""
|
||||||
|
batch_size, num_channels, height, width = latents.shape
|
||||||
|
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
|
||||||
|
return latents
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def prepare_latents(
|
||||||
|
comp,
|
||||||
|
batch_size,
|
||||||
|
num_channels_latents,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
dtype,
|
||||||
|
device,
|
||||||
|
generator,
|
||||||
|
latents=None,
|
||||||
|
):
|
||||||
|
height = 2 * (int(height) // (comp.vae_scale_factor * 2))
|
||||||
|
width = 2 * (int(width) // (comp.vae_scale_factor * 2))
|
||||||
|
|
||||||
|
shape = (batch_size, num_channels_latents * 4, height // 2, width // 2)
|
||||||
|
if isinstance(generator, list) and len(generator) != batch_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||||
|
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||||
|
)
|
||||||
|
if latents is None:
|
||||||
|
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
latents = latents.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
return latents
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
block_state.height = block_state.height or components.default_height
|
||||||
|
block_state.width = block_state.width or components.default_width
|
||||||
|
block_state.device = components._execution_device
|
||||||
|
block_state.num_channels_latents = components.num_channels_latents
|
||||||
|
|
||||||
|
self.check_inputs(components, block_state)
|
||||||
|
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||||
|
|
||||||
|
latents = self.prepare_latents(
|
||||||
|
components,
|
||||||
|
batch_size,
|
||||||
|
block_state.num_channels_latents,
|
||||||
|
block_state.height,
|
||||||
|
block_state.width,
|
||||||
|
block_state.dtype,
|
||||||
|
block_state.device,
|
||||||
|
block_state.generator,
|
||||||
|
block_state.latents,
|
||||||
|
)
|
||||||
|
|
||||||
|
latent_ids = self._prepare_latent_ids(latents)
|
||||||
|
latent_ids = latent_ids.to(block_state.device)
|
||||||
|
|
||||||
|
latents = self._pack_latents(latents)
|
||||||
|
|
||||||
|
block_state.latents = latents
|
||||||
|
block_state.latent_ids = latent_ids
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2RoPEInputsStep(ModularPipelineBlocks):
|
||||||
|
model_name = "flux2"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Step that prepares the 4D RoPE position IDs for Flux2 denoising. Should be placed after text encoder and latent preparation steps."
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam(name="prompt_embeds", required=True),
|
||||||
|
InputParam(name="latent_ids"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
name="txt_ids",
|
||||||
|
kwargs_type="denoiser_input_fields",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="4D position IDs (T, H, W, L) for text tokens, used for RoPE calculation.",
|
||||||
|
),
|
||||||
|
OutputParam(
|
||||||
|
name="latent_ids",
|
||||||
|
kwargs_type="denoiser_input_fields",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="4D position IDs (T, H, W, L) for image latents, used for RoPE calculation.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _prepare_text_ids(x: torch.Tensor, t_coord: Optional[torch.Tensor] = None):
|
||||||
|
"""Prepare 4D position IDs for text tokens."""
|
||||||
|
B, L, _ = x.shape
|
||||||
|
out_ids = []
|
||||||
|
|
||||||
|
for i in range(B):
|
||||||
|
t = torch.arange(1) if t_coord is None else t_coord[i]
|
||||||
|
h = torch.arange(1)
|
||||||
|
w = torch.arange(1)
|
||||||
|
seq_l = torch.arange(L)
|
||||||
|
|
||||||
|
coords = torch.cartesian_prod(t, h, w, seq_l)
|
||||||
|
out_ids.append(coords)
|
||||||
|
|
||||||
|
return torch.stack(out_ids)
|
||||||
|
|
||||||
|
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
|
prompt_embeds = block_state.prompt_embeds
|
||||||
|
device = prompt_embeds.device
|
||||||
|
|
||||||
|
block_state.txt_ids = self._prepare_text_ids(prompt_embeds)
|
||||||
|
block_state.txt_ids = block_state.txt_ids.to(device)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2PrepareImageLatentsStep(ModularPipelineBlocks):
|
||||||
|
model_name = "flux2"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Step that prepares image latents and their position IDs for Flux2 image conditioning."
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("image_latents", type_hint=List[torch.Tensor]),
|
||||||
|
InputParam("batch_size", required=True, type_hint=int),
|
||||||
|
InputParam("num_images_per_prompt", default=1, type_hint=int),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"image_latents",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="Packed image latents for conditioning",
|
||||||
|
),
|
||||||
|
OutputParam(
|
||||||
|
"image_latent_ids",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="Position IDs for image latents",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _prepare_image_ids(image_latents: List[torch.Tensor], scale: int = 10):
|
||||||
|
"""
|
||||||
|
Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_latents: A list of image latent feature tensors of shape (1, C, H, W).
|
||||||
|
scale: Factor used to define the time separation between latents.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Combined coordinate tensor of shape (1, N_total, 4)
|
||||||
|
"""
|
||||||
|
if not isinstance(image_latents, list):
|
||||||
|
raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
|
||||||
|
|
||||||
|
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
|
||||||
|
t_coords = [t.view(-1) for t in t_coords]
|
||||||
|
|
||||||
|
image_latent_ids = []
|
||||||
|
for x, t in zip(image_latents, t_coords):
|
||||||
|
x = x.squeeze(0)
|
||||||
|
_, height, width = x.shape
|
||||||
|
|
||||||
|
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
|
||||||
|
image_latent_ids.append(x_ids)
|
||||||
|
|
||||||
|
image_latent_ids = torch.cat(image_latent_ids, dim=0)
|
||||||
|
image_latent_ids = image_latent_ids.unsqueeze(0)
|
||||||
|
|
||||||
|
return image_latent_ids
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _pack_latents(latents):
|
||||||
|
"""Pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)"""
|
||||||
|
batch_size, num_channels, height, width = latents.shape
|
||||||
|
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
|
||||||
|
return latents
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
image_latents = block_state.image_latents
|
||||||
|
|
||||||
|
if image_latents is None:
|
||||||
|
block_state.image_latents = None
|
||||||
|
block_state.image_latent_ids = None
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
device = components._execution_device
|
||||||
|
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||||
|
|
||||||
|
image_latent_ids = self._prepare_image_ids(image_latents)
|
||||||
|
|
||||||
|
packed_latents = []
|
||||||
|
for latent in image_latents:
|
||||||
|
packed = self._pack_latents(latent)
|
||||||
|
packed = packed.squeeze(0)
|
||||||
|
packed_latents.append(packed)
|
||||||
|
|
||||||
|
image_latents = torch.cat(packed_latents, dim=0)
|
||||||
|
image_latents = image_latents.unsqueeze(0)
|
||||||
|
|
||||||
|
image_latents = image_latents.repeat(batch_size, 1, 1)
|
||||||
|
image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
|
||||||
|
image_latent_ids = image_latent_ids.to(device)
|
||||||
|
|
||||||
|
block_state.image_latents = image_latents
|
||||||
|
block_state.image_latent_ids = image_latent_ids
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
146
src/diffusers/modular_pipelines/flux2/decoders.py
Normal file
146
src/diffusers/modular_pipelines/flux2/decoders.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, List, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...configuration_utils import FrozenDict
|
||||||
|
from ...models import AutoencoderKLFlux2
|
||||||
|
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
|
||||||
|
from ...utils import logging
|
||||||
|
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||||
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2DecodeStep(ModularPipelineBlocks):
|
||||||
|
model_name = "flux2"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("vae", AutoencoderKLFlux2),
|
||||||
|
ComponentSpec(
|
||||||
|
"image_processor",
|
||||||
|
Flux2ImageProcessor,
|
||||||
|
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
|
||||||
|
default_creation_method="from_config",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Step that decodes the denoised latents into images using Flux2 VAE with batch norm denormalization"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[Tuple[str, Any]]:
|
||||||
|
return [
|
||||||
|
InputParam("output_type", default="pil"),
|
||||||
|
InputParam(
|
||||||
|
"latents",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="The denoised latents from the denoising step",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"latent_ids",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="Position IDs for the latents, used for unpacking",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[str]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"images",
|
||||||
|
type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
|
||||||
|
description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Unpack latents using position IDs to scatter tokens into place.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Packed latents tensor of shape (B, seq_len, C)
|
||||||
|
x_ids: Position IDs tensor of shape (B, seq_len, 4) with (T, H, W, L) coordinates
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Unpacked latents tensor of shape (B, C, H, W)
|
||||||
|
"""
|
||||||
|
x_list = []
|
||||||
|
for data, pos in zip(x, x_ids):
|
||||||
|
_, ch = data.shape # noqa: F841
|
||||||
|
h_ids = pos[:, 1].to(torch.int64)
|
||||||
|
w_ids = pos[:, 2].to(torch.int64)
|
||||||
|
|
||||||
|
h = torch.max(h_ids) + 1
|
||||||
|
w = torch.max(w_ids) + 1
|
||||||
|
|
||||||
|
flat_ids = h_ids * w + w_ids
|
||||||
|
|
||||||
|
out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
|
||||||
|
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
|
||||||
|
|
||||||
|
out = out.view(h, w, ch).permute(2, 0, 1)
|
||||||
|
x_list.append(out)
|
||||||
|
|
||||||
|
return torch.stack(x_list, dim=0)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _unpatchify_latents(latents):
|
||||||
|
"""Convert patchified latents back to regular format."""
|
||||||
|
batch_size, num_channels_latents, height, width = latents.shape
|
||||||
|
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
|
||||||
|
latents = latents.permute(0, 1, 4, 2, 5, 3)
|
||||||
|
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
|
||||||
|
return latents
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
vae = components.vae
|
||||||
|
|
||||||
|
if block_state.output_type == "latent":
|
||||||
|
block_state.images = block_state.latents
|
||||||
|
else:
|
||||||
|
latents = block_state.latents
|
||||||
|
latent_ids = block_state.latent_ids
|
||||||
|
|
||||||
|
latents = self._unpack_latents_with_ids(latents, latent_ids)
|
||||||
|
|
||||||
|
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
||||||
|
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps).to(
|
||||||
|
latents.device, latents.dtype
|
||||||
|
)
|
||||||
|
latents = latents * latents_bn_std + latents_bn_mean
|
||||||
|
|
||||||
|
latents = self._unpatchify_latents(latents)
|
||||||
|
|
||||||
|
block_state.images = vae.decode(latents, return_dict=False)[0]
|
||||||
|
block_state.images = components.image_processor.postprocess(
|
||||||
|
block_state.images, output_type=block_state.output_type
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
252
src/diffusers/modular_pipelines/flux2/denoise.py
Normal file
252
src/diffusers/modular_pipelines/flux2/denoise.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...models import Flux2Transformer2DModel
|
||||||
|
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||||
|
from ...utils import is_torch_xla_available, logging
|
||||||
|
from ..modular_pipeline import (
|
||||||
|
BlockState,
|
||||||
|
LoopSequentialPipelineBlocks,
|
||||||
|
ModularPipelineBlocks,
|
||||||
|
PipelineState,
|
||||||
|
)
|
||||||
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
|
from .modular_pipeline import Flux2ModularPipeline
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_xla_available():
|
||||||
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
|
XLA_AVAILABLE = True
|
||||||
|
else:
|
||||||
|
XLA_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2LoopDenoiser(ModularPipelineBlocks):
|
||||||
|
model_name = "flux2"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [ComponentSpec("transformer", Flux2Transformer2DModel)]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Step within the denoising loop that denoises the latents for Flux2. "
|
||||||
|
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||||
|
"object (e.g. `Flux2DenoiseLoopWrapper`)"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[Tuple[str, Any]]:
|
||||||
|
return [
|
||||||
|
InputParam("joint_attention_kwargs"),
|
||||||
|
InputParam(
|
||||||
|
"latents",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="The latents to denoise. Shape: (B, seq_len, C)",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"image_latents",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"image_latent_ids",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"guidance",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="Guidance scale as a tensor",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"prompt_embeds",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="Text embeddings from Mistral3",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"txt_ids",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="4D position IDs for text tokens (T, H, W, L)",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"latent_ids",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="4D position IDs for latent tokens (T, H, W, L)",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
|
||||||
|
) -> PipelineState:
|
||||||
|
latents = block_state.latents
|
||||||
|
latent_model_input = latents.to(components.transformer.dtype)
|
||||||
|
img_ids = block_state.latent_ids
|
||||||
|
|
||||||
|
image_latents = getattr(block_state, "image_latents", None)
|
||||||
|
if image_latents is not None:
|
||||||
|
latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
|
||||||
|
image_latent_ids = block_state.image_latent_ids
|
||||||
|
img_ids = torch.cat([img_ids, image_latent_ids], dim=1)
|
||||||
|
|
||||||
|
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||||
|
|
||||||
|
noise_pred = components.transformer(
|
||||||
|
hidden_states=latent_model_input,
|
||||||
|
timestep=timestep / 1000,
|
||||||
|
guidance=block_state.guidance,
|
||||||
|
encoder_hidden_states=block_state.prompt_embeds,
|
||||||
|
txt_ids=block_state.txt_ids,
|
||||||
|
img_ids=img_ids,
|
||||||
|
joint_attention_kwargs=block_state.joint_attention_kwargs,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
noise_pred = noise_pred[:, : latents.size(1)]
|
||||||
|
block_state.noise_pred = noise_pred
|
||||||
|
|
||||||
|
return components, block_state
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2LoopAfterDenoiser(ModularPipelineBlocks):
|
||||||
|
model_name = "flux2"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Step within the denoising loop that updates the latents after denoising. "
|
||||||
|
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
|
||||||
|
"object (e.g. `Flux2DenoiseLoopWrapper`)"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[Tuple[str, Any]]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_inputs(self) -> List[str]:
|
||||||
|
return [InputParam("generator")]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
|
||||||
|
latents_dtype = block_state.latents.dtype
|
||||||
|
block_state.latents = components.scheduler.step(
|
||||||
|
block_state.noise_pred,
|
||||||
|
t,
|
||||||
|
block_state.latents,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
if block_state.latents.dtype != latents_dtype:
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
block_state.latents = block_state.latents.to(latents_dtype)
|
||||||
|
|
||||||
|
return components, block_state
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||||
|
model_name = "flux2"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Pipeline block that iteratively denoises the latents over `timesteps`. "
|
||||||
|
"The specific steps within each iteration can be customized with `sub_blocks` attribute"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loop_expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||||
|
ComponentSpec("transformer", Flux2Transformer2DModel),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loop_inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam(
|
||||||
|
"timesteps",
|
||||||
|
required=True,
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="The timesteps to use for the denoising process.",
|
||||||
|
),
|
||||||
|
InputParam(
|
||||||
|
"num_inference_steps",
|
||||||
|
required=True,
|
||||||
|
type_hint=int,
|
||||||
|
description="The number of inference steps to use for the denoising process.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
|
block_state.num_warmup_steps = max(
|
||||||
|
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
|
||||||
|
for i, t in enumerate(block_state.timesteps):
|
||||||
|
components, block_state = self.loop_step(components, block_state, i=i, t=t)
|
||||||
|
|
||||||
|
if i == len(block_state.timesteps) - 1 or (
|
||||||
|
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
|
||||||
|
):
|
||||||
|
progress_bar.update()
|
||||||
|
|
||||||
|
if XLA_AVAILABLE:
|
||||||
|
xm.mark_step()
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
|
||||||
|
block_classes = [Flux2LoopDenoiser, Flux2LoopAfterDenoiser]
|
||||||
|
block_names = ["denoiser", "after_denoiser"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"Denoise step that iteratively denoises the latents for Flux2. \n"
|
||||||
|
"Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
|
||||||
|
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
|
||||||
|
" - `Flux2LoopDenoiser`\n"
|
||||||
|
" - `Flux2LoopAfterDenoiser`\n"
|
||||||
|
"This block supports both text-to-image and image-conditioned generation."
|
||||||
|
)
|
||||||
349
src/diffusers/modular_pipelines/flux2/encoders.py
Normal file
349
src/diffusers/modular_pipelines/flux2/encoders.py
Normal file
@@ -0,0 +1,349 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoProcessor, Mistral3ForConditionalGeneration
|
||||||
|
|
||||||
|
from ...models import AutoencoderKLFlux2
|
||||||
|
from ...utils import logging
|
||||||
|
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||||
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
|
from .modular_pipeline import Flux2ModularPipeline
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
def format_text_input(prompts: List[str], system_message: str = None):
|
||||||
|
"""Format prompts for Mistral3 chat template."""
|
||||||
|
cleaned_txt = [prompt.replace("[IMG]", "") for prompt in prompts]
|
||||||
|
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [{"type": "text", "text": system_message}],
|
||||||
|
},
|
||||||
|
{"role": "user", "content": [{"type": "text", "text": prompt}]},
|
||||||
|
]
|
||||||
|
for prompt in cleaned_txt
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||||
|
def retrieve_latents(
|
||||||
|
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||||
|
):
|
||||||
|
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||||
|
return encoder_output.latent_dist.sample(generator)
|
||||||
|
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||||
|
return encoder_output.latent_dist.mode()
|
||||||
|
elif hasattr(encoder_output, "latents"):
|
||||||
|
return encoder_output.latents
|
||||||
|
else:
|
||||||
|
raise AttributeError("Could not access latents of provided encoder_output")
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2TextEncoderStep(ModularPipelineBlocks):
|
||||||
|
model_name = "flux2"
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
DEFAULT_SYSTEM_MESSAGE = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation."
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Text Encoder step that generates text embeddings using Mistral3 to guide the image generation"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec("text_encoder", Mistral3ForConditionalGeneration),
|
||||||
|
ComponentSpec("tokenizer", AutoProcessor),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("prompt"),
|
||||||
|
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
|
||||||
|
InputParam("max_sequence_length", type_hint=int, default=512, required=False),
|
||||||
|
InputParam("text_encoder_out_layers", type_hint=Tuple[int], default=(10, 20, 30), required=False),
|
||||||
|
InputParam("joint_attention_kwargs"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"prompt_embeds",
|
||||||
|
kwargs_type="denoiser_input_fields",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="Text embeddings from Mistral3 used to guide the image generation",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_inputs(block_state):
|
||||||
|
prompt = block_state.prompt
|
||||||
|
prompt_embeds = getattr(block_state, "prompt_embeds", None)
|
||||||
|
|
||||||
|
if prompt is not None and prompt_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
|
||||||
|
"Please make sure to only forward one of the two."
|
||||||
|
)
|
||||||
|
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||||
|
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_mistral_3_prompt_embeds(
|
||||||
|
text_encoder: Mistral3ForConditionalGeneration,
|
||||||
|
tokenizer: AutoProcessor,
|
||||||
|
prompt: Union[str, List[str]],
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
max_sequence_length: int = 512,
|
||||||
|
# fmt: off
|
||||||
|
system_message: str = "You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object attribution and actions without speculation.",
|
||||||
|
# fmt: on
|
||||||
|
hidden_states_layers: Tuple[int] = (10, 20, 30),
|
||||||
|
):
|
||||||
|
dtype = text_encoder.dtype if dtype is None else dtype
|
||||||
|
device = text_encoder.device if device is None else device
|
||||||
|
|
||||||
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
|
||||||
|
messages_batch = format_text_input(prompts=prompt, system_message=system_message)
|
||||||
|
|
||||||
|
inputs = tokenizer.apply_chat_template(
|
||||||
|
messages_batch,
|
||||||
|
add_generation_prompt=False,
|
||||||
|
tokenize=True,
|
||||||
|
return_dict=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
padding="max_length",
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_sequence_length,
|
||||||
|
)
|
||||||
|
|
||||||
|
input_ids = inputs["input_ids"].to(device)
|
||||||
|
attention_mask = inputs["attention_mask"].to(device)
|
||||||
|
|
||||||
|
output = text_encoder(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
output_hidden_states=True,
|
||||||
|
use_cache=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
|
||||||
|
out = out.to(dtype=dtype, device=device)
|
||||||
|
|
||||||
|
batch_size, num_channels, seq_len, hidden_dim = out.shape
|
||||||
|
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
||||||
|
|
||||||
|
return prompt_embeds
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
self.check_inputs(block_state)
|
||||||
|
|
||||||
|
block_state.device = components._execution_device
|
||||||
|
|
||||||
|
if block_state.prompt_embeds is not None:
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
prompt = block_state.prompt
|
||||||
|
if prompt is None:
|
||||||
|
prompt = ""
|
||||||
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
|
||||||
|
block_state.prompt_embeds = self._get_mistral_3_prompt_embeds(
|
||||||
|
text_encoder=components.text_encoder,
|
||||||
|
tokenizer=components.tokenizer,
|
||||||
|
prompt=prompt,
|
||||||
|
device=block_state.device,
|
||||||
|
max_sequence_length=block_state.max_sequence_length,
|
||||||
|
system_message=self.DEFAULT_SYSTEM_MESSAGE,
|
||||||
|
hidden_states_layers=block_state.text_encoder_out_layers,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2RemoteTextEncoderStep(ModularPipelineBlocks):
|
||||||
|
model_name = "flux2"
|
||||||
|
|
||||||
|
REMOTE_URL = "https://remote-text-encoder-flux-2.huggingface.co/predict"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Text Encoder step that generates text embeddings using a remote API endpoint"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("prompt"),
|
||||||
|
InputParam("prompt_embeds", type_hint=torch.Tensor, required=False),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"prompt_embeds",
|
||||||
|
kwargs_type="denoiser_input_fields",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="Text embeddings from remote API used to guide the image generation",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_inputs(block_state):
|
||||||
|
prompt = block_state.prompt
|
||||||
|
prompt_embeds = getattr(block_state, "prompt_embeds", None)
|
||||||
|
|
||||||
|
if prompt is not None and prompt_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. "
|
||||||
|
"Please make sure to only forward one of the two."
|
||||||
|
)
|
||||||
|
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||||
|
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
import io
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from huggingface_hub import get_token
|
||||||
|
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
self.check_inputs(block_state)
|
||||||
|
|
||||||
|
block_state.device = components._execution_device
|
||||||
|
|
||||||
|
if block_state.prompt_embeds is not None:
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
prompt = block_state.prompt
|
||||||
|
if prompt is None:
|
||||||
|
prompt = ""
|
||||||
|
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
self.REMOTE_URL,
|
||||||
|
json={"prompt": prompt},
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {get_token()}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
block_state.prompt_embeds = torch.load(io.BytesIO(response.content), weights_only=True)
|
||||||
|
block_state.prompt_embeds = block_state.prompt_embeds.to(block_state.device)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2VaeEncoderStep(ModularPipelineBlocks):
|
||||||
|
model_name = "flux2"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "VAE Encoder step that encodes preprocessed images into latent representations for Flux2."
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [ComponentSpec("vae", AutoencoderKLFlux2)]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("condition_images", type_hint=List[torch.Tensor]),
|
||||||
|
InputParam("generator"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"image_latents",
|
||||||
|
type_hint=List[torch.Tensor],
|
||||||
|
description="List of latent representations for each reference image",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _patchify_latents(latents):
|
||||||
|
"""Convert latents to patchified format for Flux2."""
|
||||||
|
batch_size, num_channels_latents, height, width = latents.shape
|
||||||
|
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||||
|
latents = latents.permute(0, 1, 3, 5, 2, 4)
|
||||||
|
latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
|
||||||
|
return latents
|
||||||
|
|
||||||
|
def _encode_vae_image(self, vae: AutoencoderKLFlux2, image: torch.Tensor, generator: torch.Generator):
|
||||||
|
"""Encode a single image using Flux2 VAE with batch norm normalization."""
|
||||||
|
if image.ndim != 4:
|
||||||
|
raise ValueError(f"Expected image dims 4, got {image.ndim}.")
|
||||||
|
|
||||||
|
image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode="argmax")
|
||||||
|
image_latents = self._patchify_latents(image_latents)
|
||||||
|
|
||||||
|
latents_bn_mean = vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
|
||||||
|
latents_bn_std = torch.sqrt(vae.bn.running_var.view(1, -1, 1, 1) + vae.config.batch_norm_eps)
|
||||||
|
latents_bn_std = latents_bn_std.to(image_latents.device, image_latents.dtype)
|
||||||
|
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
|
||||||
|
|
||||||
|
return image_latents
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
condition_images = block_state.condition_images
|
||||||
|
|
||||||
|
if condition_images is None:
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
device = components._execution_device
|
||||||
|
dtype = components.vae.dtype
|
||||||
|
|
||||||
|
image_latents = []
|
||||||
|
for image in condition_images:
|
||||||
|
image = image.to(device=device, dtype=dtype)
|
||||||
|
latent = self._encode_vae_image(
|
||||||
|
vae=components.vae,
|
||||||
|
image=image,
|
||||||
|
generator=block_state.generator,
|
||||||
|
)
|
||||||
|
image_latents.append(latent)
|
||||||
|
|
||||||
|
block_state.image_latents = image_latents
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
160
src/diffusers/modular_pipelines/flux2/inputs.py
Normal file
160
src/diffusers/modular_pipelines/flux2/inputs.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from ...configuration_utils import FrozenDict
|
||||||
|
from ...pipelines.flux2.image_processor import Flux2ImageProcessor
|
||||||
|
from ...utils import logging
|
||||||
|
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||||
|
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||||
|
from .modular_pipeline import Flux2ModularPipeline
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2TextInputStep(ModularPipelineBlocks):
|
||||||
|
model_name = "flux2"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return (
|
||||||
|
"This step:\n"
|
||||||
|
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
|
||||||
|
" 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("num_images_per_prompt", default=1),
|
||||||
|
InputParam(
|
||||||
|
"prompt_embeds",
|
||||||
|
required=True,
|
||||||
|
kwargs_type="denoiser_input_fields",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
description="Pre-generated text embeddings from Mistral3. Can be generated from text_encoder step.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[str]:
|
||||||
|
return [
|
||||||
|
OutputParam(
|
||||||
|
"batch_size",
|
||||||
|
type_hint=int,
|
||||||
|
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
|
||||||
|
),
|
||||||
|
OutputParam(
|
||||||
|
"dtype",
|
||||||
|
type_hint=torch.dtype,
|
||||||
|
description="Data type of model tensor inputs (determined by `prompt_embeds`)",
|
||||||
|
),
|
||||||
|
OutputParam(
|
||||||
|
"prompt_embeds",
|
||||||
|
type_hint=torch.Tensor,
|
||||||
|
kwargs_type="denoiser_input_fields",
|
||||||
|
description="Text embeddings used to guide the image generation",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> PipelineState:
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
|
||||||
|
block_state.batch_size = block_state.prompt_embeds.shape[0]
|
||||||
|
block_state.dtype = block_state.prompt_embeds.dtype
|
||||||
|
|
||||||
|
_, seq_len, _ = block_state.prompt_embeds.shape
|
||||||
|
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
|
||||||
|
block_state.prompt_embeds = block_state.prompt_embeds.view(
|
||||||
|
block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2ProcessImagesInputStep(ModularPipelineBlocks):
|
||||||
|
model_name = "flux2"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "Image preprocess step for Flux2. Validates and preprocesses reference images."
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_components(self) -> List[ComponentSpec]:
|
||||||
|
return [
|
||||||
|
ComponentSpec(
|
||||||
|
"image_processor",
|
||||||
|
Flux2ImageProcessor,
|
||||||
|
config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 32}),
|
||||||
|
default_creation_method="from_config",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> List[InputParam]:
|
||||||
|
return [
|
||||||
|
InputParam("image"),
|
||||||
|
InputParam("height"),
|
||||||
|
InputParam("width"),
|
||||||
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def intermediate_outputs(self) -> List[OutputParam]:
|
||||||
|
return [OutputParam(name="condition_images", type_hint=List[torch.Tensor])]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(self, components: Flux2ModularPipeline, state: PipelineState):
|
||||||
|
block_state = self.get_block_state(state)
|
||||||
|
images = block_state.image
|
||||||
|
|
||||||
|
if images is None:
|
||||||
|
block_state.condition_images = None
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
|
|
||||||
|
if not isinstance(images, list):
|
||||||
|
images = [images]
|
||||||
|
|
||||||
|
condition_images = []
|
||||||
|
for img in images:
|
||||||
|
components.image_processor.check_image_input(img)
|
||||||
|
|
||||||
|
image_width, image_height = img.size
|
||||||
|
if image_width * image_height > 1024 * 1024:
|
||||||
|
img = components.image_processor._resize_to_target_area(img, 1024 * 1024)
|
||||||
|
image_width, image_height = img.size
|
||||||
|
|
||||||
|
multiple_of = components.vae_scale_factor * 2
|
||||||
|
image_width = (image_width // multiple_of) * multiple_of
|
||||||
|
image_height = (image_height // multiple_of) * multiple_of
|
||||||
|
condition_img = components.image_processor.preprocess(
|
||||||
|
img, height=image_height, width=image_width, resize_mode="crop"
|
||||||
|
)
|
||||||
|
condition_images.append(condition_img)
|
||||||
|
|
||||||
|
if block_state.height is None:
|
||||||
|
block_state.height = image_height
|
||||||
|
if block_state.width is None:
|
||||||
|
block_state.width = image_width
|
||||||
|
|
||||||
|
block_state.condition_images = condition_images
|
||||||
|
|
||||||
|
self.set_block_state(state, block_state)
|
||||||
|
return components, state
|
||||||
166
src/diffusers/modular_pipelines/flux2/modular_blocks.py
Normal file
166
src/diffusers/modular_pipelines/flux2/modular_blocks.py
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ...utils import logging
|
||||||
|
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||||
|
from ..modular_pipeline_utils import InsertableDict
|
||||||
|
from .before_denoise import (
|
||||||
|
Flux2PrepareImageLatentsStep,
|
||||||
|
Flux2PrepareLatentsStep,
|
||||||
|
Flux2RoPEInputsStep,
|
||||||
|
Flux2SetTimestepsStep,
|
||||||
|
)
|
||||||
|
from .decoders import Flux2DecodeStep
|
||||||
|
from .denoise import Flux2DenoiseStep
|
||||||
|
from .encoders import (
|
||||||
|
Flux2RemoteTextEncoderStep,
|
||||||
|
Flux2TextEncoderStep,
|
||||||
|
Flux2VaeEncoderStep,
|
||||||
|
)
|
||||||
|
from .inputs import (
|
||||||
|
Flux2ProcessImagesInputStep,
|
||||||
|
Flux2TextInputStep,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
Flux2VaeEncoderBlocks = InsertableDict(
|
||||||
|
[
|
||||||
|
("preprocess", Flux2ProcessImagesInputStep()),
|
||||||
|
("encode", Flux2VaeEncoderStep()),
|
||||||
|
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2VaeEncoderSequentialStep(SequentialPipelineBlocks):
|
||||||
|
model_name = "flux2"
|
||||||
|
|
||||||
|
block_classes = Flux2VaeEncoderBlocks.values()
|
||||||
|
block_names = Flux2VaeEncoderBlocks.keys()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning."
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
|
||||||
|
block_classes = [Flux2VaeEncoderSequentialStep]
|
||||||
|
block_names = ["img_conditioning"]
|
||||||
|
block_trigger_inputs = ["image"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return (
|
||||||
|
"VAE encoder step that encodes the image inputs into their latent representations.\n"
|
||||||
|
"This is an auto pipeline block that works for image conditioning tasks.\n"
|
||||||
|
" - `Flux2VaeEncoderSequentialStep` is used when `image` is provided.\n"
|
||||||
|
" - If `image` is not provided, step will be skipped."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Flux2BeforeDenoiseBlocks = InsertableDict(
|
||||||
|
[
|
||||||
|
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||||
|
("set_timesteps", Flux2SetTimestepsStep()),
|
||||||
|
("prepare_rope_inputs", Flux2RoPEInputsStep()),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2BeforeDenoiseStep(SequentialPipelineBlocks):
|
||||||
|
model_name = "flux2"
|
||||||
|
|
||||||
|
block_classes = Flux2BeforeDenoiseBlocks.values()
|
||||||
|
block_names = Flux2BeforeDenoiseBlocks.keys()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return "Before denoise step that prepares the inputs for the denoise step in Flux2 generation."
|
||||||
|
|
||||||
|
|
||||||
|
AUTO_BLOCKS = InsertableDict(
|
||||||
|
[
|
||||||
|
("text_encoder", Flux2TextEncoderStep()),
|
||||||
|
("text_input", Flux2TextInputStep()),
|
||||||
|
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
|
||||||
|
("before_denoise", Flux2BeforeDenoiseStep()),
|
||||||
|
("denoise", Flux2DenoiseStep()),
|
||||||
|
("decode", Flux2DecodeStep()),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
REMOTE_AUTO_BLOCKS = InsertableDict(
|
||||||
|
[
|
||||||
|
("text_encoder", Flux2RemoteTextEncoderStep()),
|
||||||
|
("text_input", Flux2TextInputStep()),
|
||||||
|
("vae_image_encoder", Flux2AutoVaeEncoderStep()),
|
||||||
|
("before_denoise", Flux2BeforeDenoiseStep()),
|
||||||
|
("denoise", Flux2DenoiseStep()),
|
||||||
|
("decode", Flux2DecodeStep()),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2AutoBlocks(SequentialPipelineBlocks):
|
||||||
|
model_name = "flux2"
|
||||||
|
|
||||||
|
block_classes = AUTO_BLOCKS.values()
|
||||||
|
block_names = AUTO_BLOCKS.keys()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self):
|
||||||
|
return (
|
||||||
|
"Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2.\n"
|
||||||
|
"- For text-to-image generation, all you need to provide is `prompt`.\n"
|
||||||
|
"- For image-conditioned generation, you need to provide `image` (list of PIL images)."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||||
|
[
|
||||||
|
("text_encoder", Flux2TextEncoderStep()),
|
||||||
|
("text_input", Flux2TextInputStep()),
|
||||||
|
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||||
|
("set_timesteps", Flux2SetTimestepsStep()),
|
||||||
|
("prepare_rope_inputs", Flux2RoPEInputsStep()),
|
||||||
|
("denoise", Flux2DenoiseStep()),
|
||||||
|
("decode", Flux2DecodeStep()),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
IMAGE_CONDITIONED_BLOCKS = InsertableDict(
|
||||||
|
[
|
||||||
|
("text_encoder", Flux2TextEncoderStep()),
|
||||||
|
("text_input", Flux2TextInputStep()),
|
||||||
|
("preprocess_images", Flux2ProcessImagesInputStep()),
|
||||||
|
("vae_encoder", Flux2VaeEncoderStep()),
|
||||||
|
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
|
||||||
|
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||||
|
("set_timesteps", Flux2SetTimestepsStep()),
|
||||||
|
("prepare_rope_inputs", Flux2RoPEInputsStep()),
|
||||||
|
("denoise", Flux2DenoiseStep()),
|
||||||
|
("decode", Flux2DecodeStep()),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
ALL_BLOCKS = {
|
||||||
|
"text2image": TEXT2IMAGE_BLOCKS,
|
||||||
|
"image_conditioned": IMAGE_CONDITIONED_BLOCKS,
|
||||||
|
"auto": AUTO_BLOCKS,
|
||||||
|
"remote": REMOTE_AUTO_BLOCKS,
|
||||||
|
}
|
||||||
57
src/diffusers/modular_pipelines/flux2/modular_pipeline.py
Normal file
57
src/diffusers/modular_pipelines/flux2/modular_pipeline.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from ...loaders import Flux2LoraLoaderMixin
|
||||||
|
from ...utils import logging
|
||||||
|
from ..modular_pipeline import ModularPipeline
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2ModularPipeline(ModularPipeline, Flux2LoraLoaderMixin):
|
||||||
|
"""
|
||||||
|
A ModularPipeline for Flux2.
|
||||||
|
|
||||||
|
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||||
|
"""
|
||||||
|
|
||||||
|
default_blocks_name = "Flux2AutoBlocks"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_height(self):
|
||||||
|
return self.default_sample_size * self.vae_scale_factor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_width(self):
|
||||||
|
return self.default_sample_size * self.vae_scale_factor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_sample_size(self):
|
||||||
|
return 128
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vae_scale_factor(self):
|
||||||
|
vae_scale_factor = 8
|
||||||
|
if getattr(self, "vae", None) is not None:
|
||||||
|
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||||
|
return vae_scale_factor
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_channels_latents(self):
|
||||||
|
num_channels_latents = 32
|
||||||
|
if getattr(self, "transformer", None):
|
||||||
|
num_channels_latents = self.transformer.config.in_channels // 4
|
||||||
|
return num_channels_latents
|
||||||
@@ -58,6 +58,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
|
|||||||
("wan", "WanModularPipeline"),
|
("wan", "WanModularPipeline"),
|
||||||
("flux", "FluxModularPipeline"),
|
("flux", "FluxModularPipeline"),
|
||||||
("flux-kontext", "FluxKontextModularPipeline"),
|
("flux-kontext", "FluxKontextModularPipeline"),
|
||||||
|
("flux2", "Flux2ModularPipeline"),
|
||||||
("qwenimage", "QwenImageModularPipeline"),
|
("qwenimage", "QwenImageModularPipeline"),
|
||||||
("qwenimage-edit", "QwenImageEditModularPipeline"),
|
("qwenimage-edit", "QwenImageEditModularPipeline"),
|
||||||
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
|
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
|
||||||
@@ -1585,7 +1586,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
|||||||
for name, config_spec in self._config_specs.items():
|
for name, config_spec in self._config_specs.items():
|
||||||
default_configs[name] = config_spec.default
|
default_configs[name] = config_spec.default
|
||||||
self.register_to_config(**default_configs)
|
self.register_to_config(**default_configs)
|
||||||
|
|
||||||
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
|
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -2,6 +2,36 @@
|
|||||||
from ..utils import DummyObject, requires_backends
|
from ..utils import DummyObject, requires_backends
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2AutoBlocks(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
|
class Flux2ModularPipeline(metaclass=DummyObject):
|
||||||
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch", "transformers"])
|
||||||
|
|
||||||
|
|
||||||
class FluxAutoBlocks(metaclass=DummyObject):
|
class FluxAutoBlocks(metaclass=DummyObject):
|
||||||
_backends = ["torch", "transformers"]
|
_backends = ["torch", "transformers"]
|
||||||
|
|
||||||
|
|||||||
0
tests/modular_pipelines/flux2/__init__.py
Normal file
0
tests/modular_pipelines/flux2/__init__.py
Normal file
93
tests/modular_pipelines/flux2/test_modular_pipeline_flux2.py
Normal file
93
tests/modular_pipelines/flux2/test_modular_pipeline_flux2.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from diffusers.modular_pipelines import (
|
||||||
|
Flux2AutoBlocks,
|
||||||
|
Flux2ModularPipeline,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ...testing_utils import floats_tensor, torch_device
|
||||||
|
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
|
||||||
|
pipeline_class = Flux2ModularPipeline
|
||||||
|
pipeline_blocks_class = Flux2AutoBlocks
|
||||||
|
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
|
||||||
|
|
||||||
|
params = frozenset(["prompt", "height", "width", "guidance_scale"])
|
||||||
|
batch_params = frozenset(["prompt"])
|
||||||
|
|
||||||
|
def get_dummy_inputs(self, seed=0):
|
||||||
|
generator = self.get_generator(seed)
|
||||||
|
inputs = {
|
||||||
|
"prompt": "A painting of a squirrel eating a burger",
|
||||||
|
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
|
||||||
|
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
|
||||||
|
"text_encoder_out_layers": (1,),
|
||||||
|
"generator": generator,
|
||||||
|
"num_inference_steps": 2,
|
||||||
|
"guidance_scale": 4.0,
|
||||||
|
"height": 32,
|
||||||
|
"width": 32,
|
||||||
|
"output_type": "pt",
|
||||||
|
}
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def test_float16_inference(self):
|
||||||
|
super().test_float16_inference(9e-2)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlux2ImageConditionedModularPipelineFast(ModularPipelineTesterMixin):
|
||||||
|
pipeline_class = Flux2ModularPipeline
|
||||||
|
pipeline_blocks_class = Flux2AutoBlocks
|
||||||
|
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-modular"
|
||||||
|
|
||||||
|
params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
|
||||||
|
batch_params = frozenset(["prompt", "image"])
|
||||||
|
|
||||||
|
def get_dummy_inputs(self, seed=0):
|
||||||
|
generator = self.get_generator(seed)
|
||||||
|
inputs = {
|
||||||
|
"prompt": "A painting of a squirrel eating a burger",
|
||||||
|
# TODO (Dhruv): Update text encoder config so that vocab_size matches tokenizer
|
||||||
|
"max_sequence_length": 8, # bit of a hack to workaround vocab size mismatch
|
||||||
|
"text_encoder_out_layers": (1,),
|
||||||
|
"generator": generator,
|
||||||
|
"num_inference_steps": 2,
|
||||||
|
"guidance_scale": 4.0,
|
||||||
|
"height": 32,
|
||||||
|
"width": 32,
|
||||||
|
"output_type": "pt",
|
||||||
|
}
|
||||||
|
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(torch_device)
|
||||||
|
image = image.cpu().permute(0, 2, 3, 1)[0]
|
||||||
|
init_image = PIL.Image.fromarray(np.uint8(image * 255)).convert("RGB")
|
||||||
|
inputs["image"] = init_image
|
||||||
|
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def test_float16_inference(self):
|
||||||
|
super().test_float16_inference(9e-2)
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="batched inference is currently not supported")
|
||||||
|
def test_inference_batch_single_identical(self, batch_size=2, expected_max_diff=0.0001):
|
||||||
|
return
|
||||||
@@ -165,7 +165,6 @@ class ModularPipelineTesterMixin:
|
|||||||
expected_max_diff=1e-4,
|
expected_max_diff=1e-4,
|
||||||
):
|
):
|
||||||
pipe = self.get_pipeline().to(torch_device)
|
pipe = self.get_pipeline().to(torch_device)
|
||||||
|
|
||||||
inputs = self.get_dummy_inputs()
|
inputs = self.get_dummy_inputs()
|
||||||
|
|
||||||
# Reset generator in case it is has been used in self.get_dummy_inputs
|
# Reset generator in case it is has been used in self.get_dummy_inputs
|
||||||
|
|||||||
Reference in New Issue
Block a user