mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-13 20:17:53 +08:00
Compare commits
10 Commits
group-offl
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
764f7ede33 | ||
|
|
8d0f3e1ba8 | ||
|
|
094caf398f | ||
|
|
81c354d879 | ||
|
|
0a2c26d0a4 | ||
|
|
07c5ba8eee | ||
|
|
897aed72fa | ||
|
|
07a63e197e | ||
|
|
068c6ef6c1 | ||
|
|
94bcb8941e |
3
.github/workflows/pr_tests_gpu.yml
vendored
3
.github/workflows/pr_tests_gpu.yml
vendored
@@ -1,5 +1,8 @@
|
||||
name: Fast GPU Tests on PR
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: main
|
||||
|
||||
@@ -532,8 +532,6 @@
|
||||
title: ControlNet-XS with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnet_union
|
||||
title: ControlNetUnion
|
||||
- local: api/pipelines/cosmos
|
||||
title: Cosmos
|
||||
- local: api/pipelines/ddim
|
||||
title: DDIM
|
||||
- local: api/pipelines/ddpm
|
||||
@@ -677,6 +675,8 @@
|
||||
title: CogVideoX
|
||||
- local: api/pipelines/consisid
|
||||
title: ConsisID
|
||||
- local: api/pipelines/cosmos
|
||||
title: Cosmos
|
||||
- local: api/pipelines/framepack
|
||||
title: Framepack
|
||||
- local: api/pipelines/helios
|
||||
|
||||
@@ -17,3 +17,7 @@ A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-
|
||||
## Flux2Transformer2DModel
|
||||
|
||||
[[autodoc]] Flux2Transformer2DModel
|
||||
|
||||
## Flux2Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.transformers.transformer_flux2.Flux2Transformer2DModelOutput
|
||||
|
||||
@@ -21,29 +21,31 @@
|
||||
> [!TIP]
|
||||
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
## Loading original format checkpoints
|
||||
|
||||
Original format checkpoints that have not been converted to diffusers-expected format can be loaded using the `from_single_file` method.
|
||||
## Basic usage
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel
|
||||
from diffusers import Cosmos2_5_PredictBasePipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
|
||||
transformer = CosmosTransformer3DModel.from_single_file(
|
||||
"https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt",
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
|
||||
model_id = "nvidia/Cosmos-Predict2.5-2B"
|
||||
pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
|
||||
model_id, revision="diffusers/base/post-trained", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
|
||||
prompt = "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow advance of traffic through the frosty city corridor."
|
||||
negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
|
||||
).images[0]
|
||||
output.save("output.png")
|
||||
image=None,
|
||||
video=None,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_frames=93,
|
||||
generator=torch.Generator().manual_seed(1),
|
||||
).frames[0]
|
||||
export_to_video(output, "text2world.mp4", fps=16)
|
||||
```
|
||||
|
||||
## Cosmos2_5_TransferPipeline
|
||||
|
||||
@@ -41,5 +41,11 @@ The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a
|
||||
## Flux2KleinPipeline
|
||||
|
||||
[[autodoc]] Flux2KleinPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## Flux2KleinKVPipeline
|
||||
|
||||
[[autodoc]] Flux2KleinKVPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -44,6 +44,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
|
||||
| [ControlNet with Stable Diffusion XL](controlnet_sdxl) | text2image |
|
||||
| [ControlNet-XS](controlnetxs) | text2image |
|
||||
| [ControlNet-XS with Stable Diffusion XL](controlnetxs_sdxl) | text2image |
|
||||
| [Cosmos](cosmos) | text2video, video2video |
|
||||
| [Dance Diffusion](dance_diffusion) | unconditional audio generation |
|
||||
| [DDIM](ddim) | unconditional image generation |
|
||||
| [DDPM](ddpm) | unconditional image generation |
|
||||
|
||||
@@ -434,6 +434,12 @@ else:
|
||||
"FluxKontextAutoBlocks",
|
||||
"FluxKontextModularPipeline",
|
||||
"FluxModularPipeline",
|
||||
"HeliosAutoBlocks",
|
||||
"HeliosModularPipeline",
|
||||
"HeliosPyramidAutoBlocks",
|
||||
"HeliosPyramidDistilledAutoBlocks",
|
||||
"HeliosPyramidDistilledModularPipeline",
|
||||
"HeliosPyramidModularPipeline",
|
||||
"QwenImageAutoBlocks",
|
||||
"QwenImageEditAutoBlocks",
|
||||
"QwenImageEditModularPipeline",
|
||||
@@ -504,6 +510,7 @@ else:
|
||||
"EasyAnimateControlPipeline",
|
||||
"EasyAnimateInpaintPipeline",
|
||||
"EasyAnimatePipeline",
|
||||
"Flux2KleinKVPipeline",
|
||||
"Flux2KleinPipeline",
|
||||
"Flux2Pipeline",
|
||||
"FluxControlImg2ImgPipeline",
|
||||
@@ -1188,6 +1195,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxKontextAutoBlocks,
|
||||
FluxKontextModularPipeline,
|
||||
FluxModularPipeline,
|
||||
HeliosAutoBlocks,
|
||||
HeliosModularPipeline,
|
||||
HeliosPyramidAutoBlocks,
|
||||
HeliosPyramidDistilledAutoBlocks,
|
||||
HeliosPyramidDistilledModularPipeline,
|
||||
HeliosPyramidModularPipeline,
|
||||
QwenImageAutoBlocks,
|
||||
QwenImageEditAutoBlocks,
|
||||
QwenImageEditModularPipeline,
|
||||
@@ -1254,6 +1267,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EasyAnimateControlPipeline,
|
||||
EasyAnimateInpaintPipeline,
|
||||
EasyAnimatePipeline,
|
||||
Flux2KleinKVPipeline,
|
||||
Flux2KleinPipeline,
|
||||
Flux2Pipeline,
|
||||
FluxControlImg2ImgPipeline,
|
||||
|
||||
@@ -2538,8 +2538,12 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
|
||||
|
||||
def get_alpha_scales(down_weight, alpha_key):
|
||||
rank = down_weight.shape[0]
|
||||
alpha = state_dict.pop(alpha_key).item()
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
alpha_tensor = state_dict.pop(alpha_key, None)
|
||||
if alpha_tensor is None:
|
||||
return 1.0, 1.0
|
||||
scale = (
|
||||
alpha_tensor.item() / rank
|
||||
) # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
|
||||
@@ -60,6 +60,16 @@ class ContextParallelConfig:
|
||||
rotate_method (`str`, *optional*, defaults to `"allgather"`):
|
||||
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
|
||||
is supported.
|
||||
ulysses_anything (`bool`, *optional*, defaults to `False`):
|
||||
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
|
||||
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
|
||||
`ring_degree` must be 1.
|
||||
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
|
||||
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
|
||||
creating a new one. This is useful when combining context parallelism with other parallelism strategies
|
||||
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
|
||||
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
|
||||
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
|
||||
|
||||
"""
|
||||
|
||||
@@ -68,6 +78,7 @@ class ContextParallelConfig:
|
||||
convert_to_fp32: bool = True
|
||||
# TODO: support alltoall
|
||||
rotate_method: Literal["allgather", "alltoall"] = "allgather"
|
||||
mesh: torch.distributed.device_mesh.DeviceMesh | None = None
|
||||
# Whether to enable ulysses anything attention to support
|
||||
# any sequence lengths and any head numbers.
|
||||
ulysses_anything: bool = False
|
||||
@@ -124,7 +135,7 @@ class ContextParallelConfig:
|
||||
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
|
||||
)
|
||||
|
||||
self._flattened_mesh = self._mesh._flatten()
|
||||
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
|
||||
self._ring_mesh = self._mesh["ring"]
|
||||
self._ulysses_mesh = self._mesh["ulysses"]
|
||||
self._ring_local_rank = self._ring_mesh.get_local_rank()
|
||||
|
||||
@@ -1567,7 +1567,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
mesh = None
|
||||
if config.context_parallel_config is not None:
|
||||
cp_config = config.context_parallel_config
|
||||
mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
|
||||
device_type=device_type,
|
||||
mesh_shape=cp_config.mesh_shape,
|
||||
mesh_dim_names=cp_config.mesh_dim_names,
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -21,7 +22,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import BaseOutput, apply_lora_scale, logging
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -32,7 +33,6 @@ from ..embeddings import (
|
||||
apply_rotary_emb,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous
|
||||
|
||||
@@ -40,6 +40,216 @@ from ..normalization import AdaLayerNormContinuous
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class Flux2Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`Flux2Transformer2DModel`].
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
|
||||
The hidden states output conditioned on the `encoder_hidden_states` input.
|
||||
kv_cache (`Flux2KVCache`, *optional*):
|
||||
The populated KV cache for reference image tokens. Only returned when `kv_cache_mode="extract"`.
|
||||
"""
|
||||
|
||||
sample: "torch.Tensor" # noqa: F821
|
||||
kv_cache: "Flux2KVCache | None" = None
|
||||
|
||||
|
||||
class Flux2KVLayerCache:
|
||||
"""Per-layer KV cache for reference image tokens in the Flux2 Klein KV model.
|
||||
|
||||
Stores the K and V projections (post-RoPE) for reference tokens extracted during the first denoising step. Tensor
|
||||
format: (batch_size, num_ref_tokens, num_heads, head_dim).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.k_ref: torch.Tensor | None = None
|
||||
self.v_ref: torch.Tensor | None = None
|
||||
|
||||
def store(self, k_ref: torch.Tensor, v_ref: torch.Tensor):
|
||||
"""Store reference token K/V."""
|
||||
self.k_ref = k_ref
|
||||
self.v_ref = v_ref
|
||||
|
||||
def get(self) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Retrieve cached reference token K/V."""
|
||||
if self.k_ref is None:
|
||||
raise RuntimeError("KV cache has not been populated yet.")
|
||||
return self.k_ref, self.v_ref
|
||||
|
||||
def clear(self):
|
||||
self.k_ref = None
|
||||
self.v_ref = None
|
||||
|
||||
|
||||
class Flux2KVCache:
|
||||
"""Container for all layers' reference-token KV caches.
|
||||
|
||||
Holds separate cache lists for double-stream and single-stream transformer blocks.
|
||||
"""
|
||||
|
||||
def __init__(self, num_double_layers: int, num_single_layers: int):
|
||||
self.double_block_caches = [Flux2KVLayerCache() for _ in range(num_double_layers)]
|
||||
self.single_block_caches = [Flux2KVLayerCache() for _ in range(num_single_layers)]
|
||||
self.num_ref_tokens: int = 0
|
||||
|
||||
def get_double(self, layer_idx: int) -> Flux2KVLayerCache:
|
||||
return self.double_block_caches[layer_idx]
|
||||
|
||||
def get_single(self, layer_idx: int) -> Flux2KVLayerCache:
|
||||
return self.single_block_caches[layer_idx]
|
||||
|
||||
def clear(self):
|
||||
for cache in self.double_block_caches:
|
||||
cache.clear()
|
||||
for cache in self.single_block_caches:
|
||||
cache.clear()
|
||||
self.num_ref_tokens = 0
|
||||
|
||||
|
||||
def _flux2_kv_causal_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
num_txt_tokens: int,
|
||||
num_ref_tokens: int,
|
||||
kv_cache: Flux2KVLayerCache | None = None,
|
||||
backend=None,
|
||||
) -> torch.Tensor:
|
||||
"""Causal attention for KV caching where reference tokens only self-attend.
|
||||
|
||||
All tensors use the diffusers convention: (batch_size, seq_len, num_heads, head_dim).
|
||||
|
||||
Without cache (extract mode): sequence layout is [txt, ref, img]. txt+img tokens attend to all tokens, ref tokens
|
||||
only attend to themselves. With cache (cached mode): sequence layout is [txt, img]. Cached ref K/V are injected
|
||||
between txt and img.
|
||||
"""
|
||||
# No ref tokens and no cache — standard full attention
|
||||
if num_ref_tokens == 0 and kv_cache is None:
|
||||
return dispatch_attention_fn(query, key, value, backend=backend)
|
||||
|
||||
if kv_cache is not None:
|
||||
# Cached mode: inject ref K/V between txt and img
|
||||
k_ref, v_ref = kv_cache.get()
|
||||
|
||||
k_all = torch.cat([key[:, :num_txt_tokens], k_ref, key[:, num_txt_tokens:]], dim=1)
|
||||
v_all = torch.cat([value[:, :num_txt_tokens], v_ref, value[:, num_txt_tokens:]], dim=1)
|
||||
|
||||
return dispatch_attention_fn(query, k_all, v_all, backend=backend)
|
||||
|
||||
# Extract mode: ref tokens self-attend, txt+img attend to all
|
||||
ref_start = num_txt_tokens
|
||||
ref_end = num_txt_tokens + num_ref_tokens
|
||||
|
||||
q_txt = query[:, :ref_start]
|
||||
q_ref = query[:, ref_start:ref_end]
|
||||
q_img = query[:, ref_end:]
|
||||
|
||||
k_txt = key[:, :ref_start]
|
||||
k_ref = key[:, ref_start:ref_end]
|
||||
k_img = key[:, ref_end:]
|
||||
|
||||
v_txt = value[:, :ref_start]
|
||||
v_ref = value[:, ref_start:ref_end]
|
||||
v_img = value[:, ref_end:]
|
||||
|
||||
# txt+img attend to all tokens
|
||||
q_txt_img = torch.cat([q_txt, q_img], dim=1)
|
||||
k_all = torch.cat([k_txt, k_ref, k_img], dim=1)
|
||||
v_all = torch.cat([v_txt, v_ref, v_img], dim=1)
|
||||
attn_txt_img = dispatch_attention_fn(q_txt_img, k_all, v_all, backend=backend)
|
||||
attn_txt = attn_txt_img[:, :ref_start]
|
||||
attn_img = attn_txt_img[:, ref_start:]
|
||||
|
||||
# ref tokens self-attend only
|
||||
attn_ref = dispatch_attention_fn(q_ref, k_ref, v_ref, backend=backend)
|
||||
|
||||
return torch.cat([attn_txt, attn_ref, attn_img], dim=1)
|
||||
|
||||
|
||||
def _blend_mod_params(
|
||||
img_params: tuple[torch.Tensor, ...],
|
||||
ref_params: tuple[torch.Tensor, ...],
|
||||
num_ref: int,
|
||||
seq_len: int,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""Blend modulation parameters so that the first `num_ref` positions use `ref_params`."""
|
||||
blended = []
|
||||
for im, rm in zip(img_params, ref_params):
|
||||
if im.ndim == 2:
|
||||
im = im.unsqueeze(1)
|
||||
rm = rm.unsqueeze(1)
|
||||
B = im.shape[0]
|
||||
blended.append(
|
||||
torch.cat(
|
||||
[rm.expand(B, num_ref, -1), im.expand(B, seq_len, -1)[:, num_ref:, :]],
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
return tuple(blended)
|
||||
|
||||
|
||||
def _blend_double_block_mods(
|
||||
img_mod: torch.Tensor,
|
||||
ref_mod: torch.Tensor,
|
||||
num_ref: int,
|
||||
seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
"""Blend double-block image-stream modulations for a [ref, img] sequence layout.
|
||||
|
||||
Takes raw modulation tensors (before `Flux2Modulation.split`) and returns a blended raw tensor that is compatible
|
||||
with `Flux2Modulation.split(mod, 2)`.
|
||||
"""
|
||||
if img_mod.ndim == 2:
|
||||
img_mod = img_mod.unsqueeze(1)
|
||||
ref_mod = ref_mod.unsqueeze(1)
|
||||
img_chunks = torch.chunk(img_mod, 6, dim=-1)
|
||||
ref_chunks = torch.chunk(ref_mod, 6, dim=-1)
|
||||
img_mods = (img_chunks[0:3], img_chunks[3:6])
|
||||
ref_mods = (ref_chunks[0:3], ref_chunks[3:6])
|
||||
|
||||
all_params = []
|
||||
for img_set, ref_set in zip(img_mods, ref_mods):
|
||||
blended = _blend_mod_params(img_set, ref_set, num_ref, seq_len)
|
||||
all_params.extend(blended)
|
||||
return torch.cat(all_params, dim=-1)
|
||||
|
||||
|
||||
def _blend_single_block_mods(
|
||||
single_mod: torch.Tensor,
|
||||
ref_mod: torch.Tensor,
|
||||
num_txt: int,
|
||||
num_ref: int,
|
||||
seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
"""Blend single-block modulations for a [txt, ref, img] sequence layout.
|
||||
|
||||
Takes raw modulation tensors and returns a blended raw tensor compatible with `Flux2Modulation.split(mod, 1)`.
|
||||
"""
|
||||
if single_mod.ndim == 2:
|
||||
single_mod = single_mod.unsqueeze(1)
|
||||
ref_mod = ref_mod.unsqueeze(1)
|
||||
img_params = torch.chunk(single_mod, 3, dim=-1)
|
||||
ref_params = torch.chunk(ref_mod, 3, dim=-1)
|
||||
|
||||
blended = []
|
||||
for im, rm in zip(img_params, ref_params):
|
||||
if im.ndim == 2:
|
||||
im = im.unsqueeze(1)
|
||||
rm = rm.unsqueeze(1)
|
||||
B = im.shape[0]
|
||||
im_expanded = im.expand(B, seq_len, -1)
|
||||
rm_expanded = rm.expand(B, num_ref, -1)
|
||||
blended.append(
|
||||
torch.cat(
|
||||
[im_expanded[:, :num_txt, :], rm_expanded, im_expanded[:, num_txt + num_ref :, :]],
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
return torch.cat(blended, dim=-1)
|
||||
|
||||
|
||||
def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
@@ -181,9 +391,108 @@ class Flux2AttnProcessor:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2KVAttnProcessor:
|
||||
"""
|
||||
Attention processor for Flux2 double-stream blocks with KV caching support for reference image tokens.
|
||||
|
||||
When `kv_cache_mode` is "extract", reference token K/V are stored in the cache after RoPE and causal attention is
|
||||
used (ref tokens self-attend only, txt+img attend to all). When `kv_cache_mode` is "cached", cached ref K/V are
|
||||
injected during attention. When no KV args are provided, behaves identically to `Flux2AttnProcessor`.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "Flux2Attention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
image_rotary_emb: torch.Tensor | None = None,
|
||||
kv_cache: Flux2KVLayerCache | None = None,
|
||||
kv_cache_mode: str | None = None,
|
||||
num_ref_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
|
||||
attn, hidden_states, encoder_hidden_states
|
||||
)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if attn.added_kv_proj_dim is not None:
|
||||
encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
|
||||
encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
|
||||
encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
query = torch.cat([encoder_query, query], dim=1)
|
||||
key = torch.cat([encoder_key, key], dim=1)
|
||||
value = torch.cat([encoder_value, value], dim=1)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
num_txt_tokens = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0
|
||||
|
||||
# Extract ref K/V from the combined sequence
|
||||
if kv_cache_mode == "extract" and kv_cache is not None and num_ref_tokens > 0:
|
||||
ref_start = num_txt_tokens
|
||||
ref_end = num_txt_tokens + num_ref_tokens
|
||||
kv_cache.store(key[:, ref_start:ref_end].clone(), value[:, ref_start:ref_end].clone())
|
||||
|
||||
# Dispatch attention
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
hidden_states = _flux2_kv_causal_attention(
|
||||
query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend
|
||||
)
|
||||
elif kv_cache_mode == "cached" and kv_cache is not None:
|
||||
hidden_states = _flux2_kv_causal_attention(
|
||||
query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend
|
||||
)
|
||||
else:
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
|
||||
[encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
|
||||
)
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if encoder_hidden_states is not None:
|
||||
return hidden_states, encoder_hidden_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = Flux2AttnProcessor
|
||||
_available_processors = [Flux2AttnProcessor]
|
||||
_available_processors = [Flux2AttnProcessor, Flux2KVAttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -312,6 +621,90 @@ class Flux2ParallelSelfAttnProcessor:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2KVParallelSelfAttnProcessor:
|
||||
"""
|
||||
Attention processor for Flux2 single-stream blocks with KV caching support for reference image tokens.
|
||||
|
||||
When `kv_cache_mode` is "extract", reference token K/V are stored and causal attention is used. When
|
||||
`kv_cache_mode` is "cached", cached ref K/V are injected during attention. When no KV args are provided, behaves
|
||||
identically to `Flux2ParallelSelfAttnProcessor`.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "Flux2ParallelSelfAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
image_rotary_emb: torch.Tensor | None = None,
|
||||
kv_cache: Flux2KVLayerCache | None = None,
|
||||
kv_cache_mode: str | None = None,
|
||||
num_txt_tokens: int = 0,
|
||||
num_ref_tokens: int = 0,
|
||||
) -> torch.Tensor:
|
||||
# Parallel in (QKV + MLP in) projection
|
||||
hidden_states_proj = attn.to_qkv_mlp_proj(hidden_states)
|
||||
qkv, mlp_hidden_states = torch.split(
|
||||
hidden_states_proj, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
|
||||
)
|
||||
|
||||
query, key, value = qkv.chunk(3, dim=-1)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
|
||||
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
|
||||
|
||||
# Extract ref K/V from the combined sequence
|
||||
if kv_cache_mode == "extract" and kv_cache is not None and num_ref_tokens > 0:
|
||||
ref_start = num_txt_tokens
|
||||
ref_end = num_txt_tokens + num_ref_tokens
|
||||
kv_cache.store(key[:, ref_start:ref_end].clone(), value[:, ref_start:ref_end].clone())
|
||||
|
||||
# Dispatch attention
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
attn_output = _flux2_kv_causal_attention(
|
||||
query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend
|
||||
)
|
||||
elif kv_cache_mode == "cached" and kv_cache is not None:
|
||||
attn_output = _flux2_kv_causal_attention(
|
||||
query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend
|
||||
)
|
||||
else:
|
||||
attn_output = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
attn_output = attn_output.flatten(2, 3)
|
||||
attn_output = attn_output.to(query.dtype)
|
||||
|
||||
# Handle the feedforward (FF) logic
|
||||
mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
|
||||
|
||||
# Concatenate and parallel output projection
|
||||
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=-1)
|
||||
hidden_states = attn.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
"""
|
||||
Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
|
||||
@@ -322,7 +715,7 @@ class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
"""
|
||||
|
||||
_default_processor_cls = Flux2ParallelSelfAttnProcessor
|
||||
_available_processors = [Flux2ParallelSelfAttnProcessor]
|
||||
_available_processors = [Flux2ParallelSelfAttnProcessor, Flux2KVParallelSelfAttnProcessor]
|
||||
# Does not support QKV fusion as the QKV projections are always fused
|
||||
_supports_qkv_fusion = False
|
||||
|
||||
@@ -780,6 +1173,8 @@ class Flux2Transformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
_skip_keys = ["kv_cache"]
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
@@ -791,19 +1186,21 @@ class Flux2Transformer2DModel(
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: dict[str, Any] | None = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor | Transformer2DModelOutput:
|
||||
kv_cache: "Flux2KVCache | None" = None,
|
||||
kv_cache_mode: str | None = None,
|
||||
num_ref_tokens: int = 0,
|
||||
ref_fixed_timestep: float = 0.0,
|
||||
) -> torch.Tensor | Flux2Transformer2DModelOutput:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
The [`Flux2Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
timestep ( `torch.LongTensor`):
|
||||
timestep (`torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
@@ -811,13 +1208,23 @@ class Flux2Transformer2DModel(
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
kv_cache (`Flux2KVCache`, *optional*):
|
||||
KV cache for reference image tokens. When `kv_cache_mode` is "extract", a new cache is created and
|
||||
returned. When "cached", the provided cache is used to inject ref K/V during attention.
|
||||
kv_cache_mode (`str`, *optional*):
|
||||
One of "extract" (first step with ref tokens) or "cached" (subsequent steps using cached ref K/V). When
|
||||
`None`, standard forward pass without KV caching.
|
||||
num_ref_tokens (`int`, defaults to `0`):
|
||||
Number of reference image tokens prepended to `hidden_states` (only used when
|
||||
`kv_cache_mode="extract"`).
|
||||
ref_fixed_timestep (`float`, defaults to `0.0`):
|
||||
Fixed timestep for reference token modulation (only used when `kv_cache_mode="extract"`).
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
`tuple` where the first element is the sample tensor. When `kv_cache_mode="extract"`, also returns the
|
||||
populated `Flux2KVCache`.
|
||||
"""
|
||||
# 0. Handle input arguments
|
||||
|
||||
num_txt_tokens = encoder_hidden_states.shape[1]
|
||||
|
||||
# 1. Calculate timestep embedding and modulation parameters
|
||||
@@ -832,13 +1239,33 @@ class Flux2Transformer2DModel(
|
||||
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
|
||||
single_stream_mod = self.single_stream_modulation(temb)
|
||||
|
||||
# KV extract mode: create cache and blend modulations for ref tokens
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
num_img_tokens = hidden_states.shape[1] # includes ref tokens
|
||||
|
||||
kv_cache = Flux2KVCache(
|
||||
num_double_layers=len(self.transformer_blocks),
|
||||
num_single_layers=len(self.single_transformer_blocks),
|
||||
)
|
||||
kv_cache.num_ref_tokens = num_ref_tokens
|
||||
|
||||
# Ref tokens use a fixed timestep for modulation
|
||||
ref_timestep = torch.full_like(timestep, ref_fixed_timestep * 1000)
|
||||
ref_temb = self.time_guidance_embed(ref_timestep, guidance)
|
||||
|
||||
ref_double_mod_img = self.double_stream_modulation_img(ref_temb)
|
||||
ref_single_mod = self.single_stream_modulation(ref_temb)
|
||||
|
||||
# Blend double block img modulation: [ref_mod, img_mod]
|
||||
double_stream_mod_img = _blend_double_block_mods(
|
||||
double_stream_mod_img, ref_double_mod_img, num_ref_tokens, num_img_tokens
|
||||
)
|
||||
|
||||
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
# 3. Calculate RoPE embeddings from image and text tokens
|
||||
# NOTE: the below logic means that we can't support batched inference with images of different resolutions or
|
||||
# text prompts of differents lengths. Is this a use case we want to support?
|
||||
if img_ids.ndim == 3:
|
||||
img_ids = img_ids[0]
|
||||
if txt_ids.ndim == 3:
|
||||
@@ -851,8 +1278,29 @@ class Flux2Transformer2DModel(
|
||||
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
|
||||
)
|
||||
|
||||
# 4. Double Stream Transformer Blocks
|
||||
# 4. Build joint_attention_kwargs with KV cache info
|
||||
if kv_cache_mode == "extract":
|
||||
kv_attn_kwargs = {
|
||||
**(joint_attention_kwargs or {}),
|
||||
"kv_cache": None,
|
||||
"kv_cache_mode": "extract",
|
||||
"num_ref_tokens": num_ref_tokens,
|
||||
}
|
||||
elif kv_cache_mode == "cached" and kv_cache is not None:
|
||||
kv_attn_kwargs = {
|
||||
**(joint_attention_kwargs or {}),
|
||||
"kv_cache": None,
|
||||
"kv_cache_mode": "cached",
|
||||
"num_ref_tokens": kv_cache.num_ref_tokens,
|
||||
}
|
||||
else:
|
||||
kv_attn_kwargs = joint_attention_kwargs
|
||||
|
||||
# 5. Double Stream Transformer Blocks
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if kv_cache_mode is not None and kv_cache is not None:
|
||||
kv_attn_kwargs["kv_cache"] = kv_cache.get_double(index_block)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
@@ -861,7 +1309,7 @@ class Flux2Transformer2DModel(
|
||||
double_stream_mod_img,
|
||||
double_stream_mod_txt,
|
||||
concat_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
kv_attn_kwargs,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
@@ -870,13 +1318,30 @@ class Flux2Transformer2DModel(
|
||||
temb_mod_img=double_stream_mod_img,
|
||||
temb_mod_txt=double_stream_mod_txt,
|
||||
image_rotary_emb=concat_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
joint_attention_kwargs=kv_attn_kwargs,
|
||||
)
|
||||
|
||||
# Concatenate text and image streams for single-block inference
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# 5. Single Stream Transformer Blocks
|
||||
# Blend single block modulation for extract mode: [txt_mod, ref_mod, img_mod]
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
total_single_len = hidden_states.shape[1]
|
||||
single_stream_mod = _blend_single_block_mods(
|
||||
single_stream_mod, ref_single_mod, num_txt_tokens, num_ref_tokens, total_single_len
|
||||
)
|
||||
|
||||
# Build single-block KV kwargs (single blocks need num_txt_tokens)
|
||||
if kv_cache_mode is not None:
|
||||
kv_attn_kwargs_single = {**kv_attn_kwargs, "num_txt_tokens": num_txt_tokens}
|
||||
else:
|
||||
kv_attn_kwargs_single = kv_attn_kwargs
|
||||
|
||||
# 6. Single Stream Transformer Blocks
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if kv_cache_mode is not None and kv_cache is not None:
|
||||
kv_attn_kwargs_single["kv_cache"] = kv_cache.get_single(index_block)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
@@ -884,7 +1349,7 @@ class Flux2Transformer2DModel(
|
||||
None,
|
||||
single_stream_mod,
|
||||
concat_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
kv_attn_kwargs_single,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
@@ -892,16 +1357,25 @@ class Flux2Transformer2DModel(
|
||||
encoder_hidden_states=None,
|
||||
temb_mod=single_stream_mod,
|
||||
image_rotary_emb=concat_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
joint_attention_kwargs=kv_attn_kwargs_single,
|
||||
)
|
||||
# Remove text tokens from concatenated stream
|
||||
hidden_states = hidden_states[:, num_txt_tokens:, ...]
|
||||
|
||||
# 6. Output layers
|
||||
# Remove text tokens (and ref tokens in extract mode) from concatenated stream
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
hidden_states = hidden_states[:, num_txt_tokens + num_ref_tokens :, ...]
|
||||
else:
|
||||
hidden_states = hidden_states[:, num_txt_tokens:, ...]
|
||||
|
||||
# 7. Output layers
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if kv_cache_mode == "extract":
|
||||
if not return_dict:
|
||||
return (output, kv_cache)
|
||||
return Flux2Transformer2DModelOutput(sample=output, kv_cache=kv_cache)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
return Flux2Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -56,6 +56,14 @@ else:
|
||||
"WanImage2VideoModularPipeline",
|
||||
"Wan22Image2VideoModularPipeline",
|
||||
]
|
||||
_import_structure["helios"] = [
|
||||
"HeliosAutoBlocks",
|
||||
"HeliosModularPipeline",
|
||||
"HeliosPyramidAutoBlocks",
|
||||
"HeliosPyramidDistilledAutoBlocks",
|
||||
"HeliosPyramidDistilledModularPipeline",
|
||||
"HeliosPyramidModularPipeline",
|
||||
]
|
||||
_import_structure["flux"] = [
|
||||
"FluxAutoBlocks",
|
||||
"FluxModularPipeline",
|
||||
@@ -103,6 +111,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Flux2KleinModularPipeline,
|
||||
Flux2ModularPipeline,
|
||||
)
|
||||
from .helios import (
|
||||
HeliosAutoBlocks,
|
||||
HeliosModularPipeline,
|
||||
HeliosPyramidAutoBlocks,
|
||||
HeliosPyramidDistilledAutoBlocks,
|
||||
HeliosPyramidDistilledModularPipeline,
|
||||
HeliosPyramidModularPipeline,
|
||||
)
|
||||
from .modular_pipeline import (
|
||||
AutoPipelineBlocks,
|
||||
BlockState,
|
||||
|
||||
59
src/diffusers/modular_pipelines/helios/__init__.py
Normal file
59
src/diffusers/modular_pipelines/helios/__init__.py
Normal file
@@ -0,0 +1,59 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["modular_blocks_helios"] = ["HeliosAutoBlocks"]
|
||||
_import_structure["modular_blocks_helios_pyramid"] = ["HeliosPyramidAutoBlocks"]
|
||||
_import_structure["modular_blocks_helios_pyramid_distilled"] = ["HeliosPyramidDistilledAutoBlocks"]
|
||||
_import_structure["modular_pipeline"] = [
|
||||
"HeliosModularPipeline",
|
||||
"HeliosPyramidDistilledModularPipeline",
|
||||
"HeliosPyramidModularPipeline",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .modular_blocks_helios import HeliosAutoBlocks
|
||||
from .modular_blocks_helios_pyramid import HeliosPyramidAutoBlocks
|
||||
from .modular_blocks_helios_pyramid_distilled import HeliosPyramidDistilledAutoBlocks
|
||||
from .modular_pipeline import (
|
||||
HeliosModularPipeline,
|
||||
HeliosPyramidDistilledModularPipeline,
|
||||
HeliosPyramidModularPipeline,
|
||||
)
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
836
src/diffusers/modular_pipelines/helios/before_denoise.py
Normal file
836
src/diffusers/modular_pipelines/helios/before_denoise.py
Normal file
@@ -0,0 +1,836 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...models import HeliosTransformer3DModel
|
||||
from ...schedulers import HeliosScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import HeliosModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
class HeliosTextInputStep(ModularPipelineBlocks):
|
||||
model_name = "helios"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Input processing step that:\n"
|
||||
" 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
|
||||
" 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_videos_per_prompt`\n\n"
|
||||
"All input tensors are expected to have either batch_size=1 or match the batch_size\n"
|
||||
"of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
|
||||
"have a final batch_size of batch_size * num_videos_per_prompt."
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"num_videos_per_prompt",
|
||||
default=1,
|
||||
type_hint=int,
|
||||
description="Number of videos to generate per prompt.",
|
||||
),
|
||||
InputParam.template("prompt_embeds"),
|
||||
InputParam.template("negative_prompt_embeds"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[str]:
|
||||
return [
|
||||
OutputParam(
|
||||
"batch_size",
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt",
|
||||
),
|
||||
OutputParam(
|
||||
"dtype",
|
||||
type_hint=torch.dtype,
|
||||
description="Data type of model tensor inputs (determined by `prompt_embeds.dtype`)",
|
||||
),
|
||||
]
|
||||
|
||||
def check_inputs(self, components, block_state):
|
||||
if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None:
|
||||
if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {block_state.negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(components, block_state)
|
||||
|
||||
block_state.batch_size = block_state.prompt_embeds.shape[0]
|
||||
block_state.dtype = block_state.prompt_embeds.dtype
|
||||
|
||||
_, seq_len, _ = block_state.prompt_embeds.shape
|
||||
block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1)
|
||||
block_state.prompt_embeds = block_state.prompt_embeds.view(
|
||||
block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
if block_state.negative_prompt_embeds is not None:
|
||||
_, seq_len, _ = block_state.negative_prompt_embeds.shape
|
||||
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
|
||||
1, block_state.num_videos_per_prompt, 1
|
||||
)
|
||||
block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
|
||||
block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
# Copied from diffusers.modular_pipelines.wan.before_denoise.repeat_tensor_to_batch_size
|
||||
def repeat_tensor_to_batch_size(
|
||||
input_name: str,
|
||||
input_tensor: torch.Tensor,
|
||||
batch_size: int,
|
||||
num_videos_per_prompt: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""Repeat tensor elements to match the final batch size.
|
||||
|
||||
This function expands a tensor's batch dimension to match the final batch size (batch_size * num_videos_per_prompt)
|
||||
by repeating each element along dimension 0.
|
||||
|
||||
The input tensor must have batch size 1 or batch_size. The function will:
|
||||
- If batch size is 1: repeat each element (batch_size * num_videos_per_prompt) times
|
||||
- If batch size equals batch_size: repeat each element num_videos_per_prompt times
|
||||
|
||||
Args:
|
||||
input_name (str): Name of the input tensor (used for error messages)
|
||||
input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
|
||||
batch_size (int): The base batch size (number of prompts)
|
||||
num_videos_per_prompt (int, optional): Number of videos to generate per prompt. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The repeated tensor with final batch size (batch_size * num_videos_per_prompt)
|
||||
|
||||
Raises:
|
||||
ValueError: If input_tensor is not a torch.Tensor or has invalid batch size
|
||||
|
||||
Examples:
|
||||
tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor,
|
||||
batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape:
|
||||
[4, 3]
|
||||
|
||||
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image",
|
||||
tensor, batch_size=2, num_videos_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]])
|
||||
- shape: [4, 3]
|
||||
"""
|
||||
# make sure input is a tensor
|
||||
if not isinstance(input_tensor, torch.Tensor):
|
||||
raise ValueError(f"`{input_name}` must be a tensor")
|
||||
|
||||
# make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts
|
||||
if input_tensor.shape[0] == 1:
|
||||
repeat_by = batch_size * num_videos_per_prompt
|
||||
elif input_tensor.shape[0] == batch_size:
|
||||
repeat_by = num_videos_per_prompt
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
|
||||
)
|
||||
|
||||
# expand the tensor to match the batch_size * num_videos_per_prompt
|
||||
input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)
|
||||
|
||||
return input_tensor
|
||||
|
||||
|
||||
# Copied from diffusers.modular_pipelines.wan.before_denoise.calculate_dimension_from_latents
|
||||
def calculate_dimension_from_latents(
|
||||
latents: torch.Tensor, vae_scale_factor_temporal: int, vae_scale_factor_spatial: int
|
||||
) -> tuple[int, int]:
|
||||
"""Calculate image dimensions from latent tensor dimensions.
|
||||
|
||||
This function converts latent temporal and spatial dimensions to image temporal and spatial dimensions by
|
||||
multiplying the latent num_frames/height/width by the VAE scale factor.
|
||||
|
||||
Args:
|
||||
latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions.
|
||||
Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width]
|
||||
vae_scale_factor_temporal (int): The scale factor used by the VAE to compress temporal dimension.
|
||||
Typically 4 for most VAEs (video is 4x larger than latents in temporal dimension)
|
||||
vae_scale_factor_spatial (int): The scale factor used by the VAE to compress spatial dimension.
|
||||
Typically 8 for most VAEs (image is 8x larger than latents in each dimension)
|
||||
|
||||
Returns:
|
||||
tuple[int, int]: The calculated image dimensions as (height, width)
|
||||
|
||||
Raises:
|
||||
ValueError: If latents tensor doesn't have 4 or 5 dimensions
|
||||
|
||||
"""
|
||||
if latents.ndim != 5:
|
||||
raise ValueError(f"latents must have 5 dimensions, but got {latents.ndim}")
|
||||
|
||||
_, _, num_latent_frames, latent_height, latent_width = latents.shape
|
||||
|
||||
num_frames = (num_latent_frames - 1) * vae_scale_factor_temporal + 1
|
||||
height = latent_height * vae_scale_factor_spatial
|
||||
width = latent_width * vae_scale_factor_spatial
|
||||
|
||||
return num_frames, height, width
|
||||
|
||||
|
||||
class HeliosAdditionalInputsStep(ModularPipelineBlocks):
|
||||
"""Configurable step that standardizes inputs for the denoising step.
|
||||
|
||||
This step handles:
|
||||
1. For encoded image latents: Computes height/width from latents and expands batch size
|
||||
2. For additional_batch_inputs: Expands batch dimensions to match final batch size
|
||||
"""
|
||||
|
||||
model_name = "helios"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: list[InputParam] | None = None,
|
||||
additional_batch_inputs: list[InputParam] | None = None,
|
||||
):
|
||||
if image_latent_inputs is None:
|
||||
image_latent_inputs = [InputParam.template("image_latents")]
|
||||
if additional_batch_inputs is None:
|
||||
additional_batch_inputs = []
|
||||
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
raise ValueError(f"image_latent_inputs must be a list, but got {type(image_latent_inputs)}")
|
||||
else:
|
||||
for input_param in image_latent_inputs:
|
||||
if not isinstance(input_param, InputParam):
|
||||
raise ValueError(f"image_latent_inputs must be a list of InputParam, but got {type(input_param)}")
|
||||
|
||||
if not isinstance(additional_batch_inputs, list):
|
||||
raise ValueError(f"additional_batch_inputs must be a list, but got {type(additional_batch_inputs)}")
|
||||
else:
|
||||
for input_param in additional_batch_inputs:
|
||||
if not isinstance(input_param, InputParam):
|
||||
raise ValueError(
|
||||
f"additional_batch_inputs must be a list of InputParam, but got {type(input_param)}"
|
||||
)
|
||||
|
||||
self._image_latent_inputs = image_latent_inputs
|
||||
self._additional_batch_inputs = additional_batch_inputs
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
summary_section = (
|
||||
"Input processing step that:\n"
|
||||
" 1. For image latent inputs: Computes height/width from latents and expands batch size\n"
|
||||
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
|
||||
)
|
||||
|
||||
inputs_info = ""
|
||||
if self._image_latent_inputs or self._additional_batch_inputs:
|
||||
inputs_info = "\n\nConfigured inputs:"
|
||||
if self._image_latent_inputs:
|
||||
inputs_info += f"\n - Image latent inputs: {[p.name for p in self._image_latent_inputs]}"
|
||||
if self._additional_batch_inputs:
|
||||
inputs_info += f"\n - Additional batch inputs: {[p.name for p in self._additional_batch_inputs]}"
|
||||
|
||||
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
|
||||
|
||||
return summary_section + inputs_info + placement_section
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
inputs = [
|
||||
InputParam(name="num_videos_per_prompt", default=1),
|
||||
InputParam(name="batch_size", required=True),
|
||||
]
|
||||
inputs += self._image_latent_inputs + self._additional_batch_inputs
|
||||
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
outputs = [
|
||||
OutputParam("height", type_hint=int),
|
||||
OutputParam("width", type_hint=int),
|
||||
]
|
||||
|
||||
for input_param in self._image_latent_inputs:
|
||||
outputs.append(OutputParam(input_param.name, type_hint=torch.Tensor))
|
||||
|
||||
for input_param in self._additional_batch_inputs:
|
||||
outputs.append(OutputParam(input_param.name, type_hint=torch.Tensor))
|
||||
|
||||
return outputs
|
||||
|
||||
def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
for input_param in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, input_param.name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
|
||||
# Calculate height/width from latents
|
||||
_, height, width = calculate_dimension_from_latents(
|
||||
image_latent_tensor, components.vae_scale_factor_temporal, components.vae_scale_factor_spatial
|
||||
)
|
||||
block_state.height = height
|
||||
block_state.width = width
|
||||
|
||||
# Expand batch size
|
||||
image_latent_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=input_param.name,
|
||||
input_tensor=image_latent_tensor,
|
||||
num_videos_per_prompt=block_state.num_videos_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, input_param.name, image_latent_tensor)
|
||||
|
||||
for input_param in self._additional_batch_inputs:
|
||||
input_tensor = getattr(block_state, input_param.name)
|
||||
if input_tensor is None:
|
||||
continue
|
||||
|
||||
input_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=input_param.name,
|
||||
input_tensor=input_tensor,
|
||||
num_videos_per_prompt=block_state.num_videos_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, input_param.name, input_tensor)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class HeliosAddNoiseToImageLatentsStep(ModularPipelineBlocks):
|
||||
"""Adds noise to image_latents and fake_image_latents for I2V conditioning.
|
||||
|
||||
Applies single-sigma noise to image_latents (using image_noise_sigma range) and single-sigma noise to
|
||||
fake_image_latents (using video_noise_sigma range).
|
||||
"""
|
||||
|
||||
model_name = "helios"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Adds noise to image_latents and fake_image_latents for I2V conditioning. "
|
||||
"Uses random sigma from configured ranges for each."
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam.template("image_latents"),
|
||||
InputParam(
|
||||
"fake_image_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="Fake image latents used as history seed for I2V generation.",
|
||||
),
|
||||
InputParam(
|
||||
"image_noise_sigma_min",
|
||||
default=0.111,
|
||||
type_hint=float,
|
||||
description="Minimum sigma for image latent noise.",
|
||||
),
|
||||
InputParam(
|
||||
"image_noise_sigma_max",
|
||||
default=0.135,
|
||||
type_hint=float,
|
||||
description="Maximum sigma for image latent noise.",
|
||||
),
|
||||
InputParam(
|
||||
"video_noise_sigma_min",
|
||||
default=0.111,
|
||||
type_hint=float,
|
||||
description="Minimum sigma for video/fake-image latent noise.",
|
||||
),
|
||||
InputParam(
|
||||
"video_noise_sigma_max",
|
||||
default=0.135,
|
||||
type_hint=float,
|
||||
description="Maximum sigma for video/fake-image latent noise.",
|
||||
),
|
||||
InputParam.template("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam.template("image_latents"),
|
||||
OutputParam("fake_image_latents", type_hint=torch.Tensor, description="Noisy fake image latents"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
image_latents = block_state.image_latents
|
||||
fake_image_latents = block_state.fake_image_latents
|
||||
|
||||
# Add noise to image_latents
|
||||
image_noise_sigma = (
|
||||
torch.rand(1, device=device, generator=block_state.generator)
|
||||
* (block_state.image_noise_sigma_max - block_state.image_noise_sigma_min)
|
||||
+ block_state.image_noise_sigma_min
|
||||
)
|
||||
image_latents = (
|
||||
image_noise_sigma * randn_tensor(image_latents.shape, generator=block_state.generator, device=device)
|
||||
+ (1 - image_noise_sigma) * image_latents
|
||||
)
|
||||
|
||||
# Add noise to fake_image_latents
|
||||
fake_image_noise_sigma = (
|
||||
torch.rand(1, device=device, generator=block_state.generator)
|
||||
* (block_state.video_noise_sigma_max - block_state.video_noise_sigma_min)
|
||||
+ block_state.video_noise_sigma_min
|
||||
)
|
||||
fake_image_latents = (
|
||||
fake_image_noise_sigma
|
||||
* randn_tensor(fake_image_latents.shape, generator=block_state.generator, device=device)
|
||||
+ (1 - fake_image_noise_sigma) * fake_image_latents
|
||||
)
|
||||
|
||||
block_state.image_latents = image_latents.to(device=device, dtype=torch.float32)
|
||||
block_state.fake_image_latents = fake_image_latents.to(device=device, dtype=torch.float32)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class HeliosAddNoiseToVideoLatentsStep(ModularPipelineBlocks):
|
||||
"""Adds noise to image_latents and video_latents for V2V conditioning.
|
||||
|
||||
Applies single-sigma noise to image_latents (using image_noise_sigma range) and per-frame noise to video_latents in
|
||||
chunks (using video_noise_sigma range).
|
||||
"""
|
||||
|
||||
model_name = "helios"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Adds noise to image_latents and video_latents for V2V conditioning. "
|
||||
"Uses single-sigma noise for image_latents and per-frame noise for video chunks."
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam.template("image_latents"),
|
||||
InputParam(
|
||||
"video_latents",
|
||||
required=True,
|
||||
type_hint=torch.Tensor,
|
||||
description="Encoded video latents for V2V generation.",
|
||||
),
|
||||
InputParam(
|
||||
"num_latent_frames_per_chunk",
|
||||
default=9,
|
||||
type_hint=int,
|
||||
description="Number of latent frames per temporal chunk.",
|
||||
),
|
||||
InputParam(
|
||||
"image_noise_sigma_min",
|
||||
default=0.111,
|
||||
type_hint=float,
|
||||
description="Minimum sigma for image latent noise.",
|
||||
),
|
||||
InputParam(
|
||||
"image_noise_sigma_max",
|
||||
default=0.135,
|
||||
type_hint=float,
|
||||
description="Maximum sigma for image latent noise.",
|
||||
),
|
||||
InputParam(
|
||||
"video_noise_sigma_min",
|
||||
default=0.111,
|
||||
type_hint=float,
|
||||
description="Minimum sigma for video latent noise.",
|
||||
),
|
||||
InputParam(
|
||||
"video_noise_sigma_max",
|
||||
default=0.135,
|
||||
type_hint=float,
|
||||
description="Maximum sigma for video latent noise.",
|
||||
),
|
||||
InputParam.template("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam.template("image_latents"),
|
||||
OutputParam("video_latents", type_hint=torch.Tensor, description="Noisy video latents"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
image_latents = block_state.image_latents
|
||||
video_latents = block_state.video_latents
|
||||
num_latent_frames_per_chunk = block_state.num_latent_frames_per_chunk
|
||||
|
||||
# Add noise to first frame (single sigma)
|
||||
image_noise_sigma = (
|
||||
torch.rand(1, device=device, generator=block_state.generator)
|
||||
* (block_state.image_noise_sigma_max - block_state.image_noise_sigma_min)
|
||||
+ block_state.image_noise_sigma_min
|
||||
)
|
||||
image_latents = (
|
||||
image_noise_sigma * randn_tensor(image_latents.shape, generator=block_state.generator, device=device)
|
||||
+ (1 - image_noise_sigma) * image_latents
|
||||
)
|
||||
|
||||
# Add per-frame noise to video chunks
|
||||
noisy_latents_chunks = []
|
||||
num_latent_chunks = video_latents.shape[2] // num_latent_frames_per_chunk
|
||||
for i in range(num_latent_chunks):
|
||||
chunk_start = i * num_latent_frames_per_chunk
|
||||
chunk_end = chunk_start + num_latent_frames_per_chunk
|
||||
latent_chunk = video_latents[:, :, chunk_start:chunk_end, :, :]
|
||||
|
||||
chunk_frames = latent_chunk.shape[2]
|
||||
frame_sigmas = (
|
||||
torch.rand(chunk_frames, device=device, generator=block_state.generator)
|
||||
* (block_state.video_noise_sigma_max - block_state.video_noise_sigma_min)
|
||||
+ block_state.video_noise_sigma_min
|
||||
)
|
||||
frame_sigmas = frame_sigmas.view(1, 1, chunk_frames, 1, 1)
|
||||
|
||||
noisy_chunk = (
|
||||
frame_sigmas * randn_tensor(latent_chunk.shape, generator=block_state.generator, device=device)
|
||||
+ (1 - frame_sigmas) * latent_chunk
|
||||
)
|
||||
noisy_latents_chunks.append(noisy_chunk)
|
||||
video_latents = torch.cat(noisy_latents_chunks, dim=2)
|
||||
|
||||
block_state.image_latents = image_latents.to(device=device, dtype=torch.float32)
|
||||
block_state.video_latents = video_latents.to(device=device, dtype=torch.float32)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class HeliosPrepareHistoryStep(ModularPipelineBlocks):
|
||||
"""Prepares chunk/history indices and initializes history state for the chunk loop."""
|
||||
|
||||
model_name = "helios"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Prepares the chunk loop by computing latent dimensions, number of chunks, "
|
||||
"history indices, and initializing history state (history_latents, image_latents, latent_chunks)."
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("transformer", HeliosTransformer3DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam.template("height", default=384),
|
||||
InputParam.template("width", default=640),
|
||||
InputParam(
|
||||
"num_frames", default=132, type_hint=int, description="Total number of video frames to generate."
|
||||
),
|
||||
InputParam("batch_size", required=True, type_hint=int),
|
||||
InputParam(
|
||||
"num_latent_frames_per_chunk",
|
||||
default=9,
|
||||
type_hint=int,
|
||||
description="Number of latent frames per temporal chunk.",
|
||||
),
|
||||
InputParam(
|
||||
"history_sizes",
|
||||
default=[16, 2, 1],
|
||||
type_hint=list,
|
||||
description="Sizes of long/mid/short history buffers for temporal context.",
|
||||
),
|
||||
InputParam(
|
||||
"keep_first_frame",
|
||||
default=True,
|
||||
type_hint=bool,
|
||||
description="Whether to keep the first frame as a prefix in history.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("num_latent_chunk", type_hint=int, description="Number of temporal chunks"),
|
||||
OutputParam("latent_shape", type_hint=tuple, description="Shape of latent tensor per chunk"),
|
||||
OutputParam("history_sizes", type_hint=list, description="Adjusted history sizes (sorted, descending)"),
|
||||
OutputParam("indices_hidden_states", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
|
||||
OutputParam("indices_latents_history_short", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
|
||||
OutputParam("indices_latents_history_mid", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
|
||||
OutputParam("indices_latents_history_long", type_hint=torch.Tensor, kwargs_type="denoiser_input_fields"),
|
||||
OutputParam("history_latents", type_hint=torch.Tensor, description="Initialized zero history latents"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
batch_size = block_state.batch_size
|
||||
device = components._execution_device
|
||||
|
||||
block_state.num_frames = max(block_state.num_frames, 1)
|
||||
history_sizes = sorted(block_state.history_sizes, reverse=True)
|
||||
|
||||
num_channels_latents = components.num_channels_latents
|
||||
h_latent = block_state.height // components.vae_scale_factor_spatial
|
||||
w_latent = block_state.width // components.vae_scale_factor_spatial
|
||||
|
||||
# Compute number of chunks
|
||||
block_state.window_num_frames = (
|
||||
block_state.num_latent_frames_per_chunk - 1
|
||||
) * components.vae_scale_factor_temporal + 1
|
||||
block_state.num_latent_chunk = max(
|
||||
1, (block_state.num_frames + block_state.window_num_frames - 1) // block_state.window_num_frames
|
||||
)
|
||||
|
||||
# Modify history_sizes for non-keep_first_frame (matching pipeline behavior)
|
||||
if not block_state.keep_first_frame:
|
||||
history_sizes = history_sizes.copy()
|
||||
history_sizes[-1] = history_sizes[-1] + 1
|
||||
|
||||
# Compute indices ONCE (same structure for all chunks)
|
||||
if block_state.keep_first_frame:
|
||||
indices = torch.arange(0, sum([1, *history_sizes, block_state.num_latent_frames_per_chunk]))
|
||||
(
|
||||
indices_prefix,
|
||||
indices_latents_history_long,
|
||||
indices_latents_history_mid,
|
||||
indices_latents_history_1x,
|
||||
indices_hidden_states,
|
||||
) = indices.split([1, *history_sizes, block_state.num_latent_frames_per_chunk], dim=0)
|
||||
indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0)
|
||||
else:
|
||||
indices = torch.arange(0, sum([*history_sizes, block_state.num_latent_frames_per_chunk]))
|
||||
(
|
||||
indices_latents_history_long,
|
||||
indices_latents_history_mid,
|
||||
indices_latents_history_short,
|
||||
indices_hidden_states,
|
||||
) = indices.split([*history_sizes, block_state.num_latent_frames_per_chunk], dim=0)
|
||||
|
||||
# Latent shape per chunk
|
||||
block_state.latent_shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
block_state.num_latent_frames_per_chunk,
|
||||
h_latent,
|
||||
w_latent,
|
||||
)
|
||||
|
||||
# Set outputs
|
||||
block_state.history_sizes = history_sizes
|
||||
block_state.indices_hidden_states = indices_hidden_states.unsqueeze(0)
|
||||
block_state.indices_latents_history_short = indices_latents_history_short.unsqueeze(0)
|
||||
block_state.indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0)
|
||||
block_state.indices_latents_history_long = indices_latents_history_long.unsqueeze(0)
|
||||
block_state.history_latents = torch.zeros(
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
sum(history_sizes),
|
||||
h_latent,
|
||||
w_latent,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
|
||||
|
||||
class HeliosI2VSeedHistoryStep(ModularPipelineBlocks):
|
||||
"""Seeds history_latents with fake_image_latents for I2V pipelines.
|
||||
|
||||
This small additive step runs after HeliosPrepareHistoryStep and appends fake_image_latents to the initialized
|
||||
history_latents tensor.
|
||||
"""
|
||||
|
||||
model_name = "helios"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "I2V history seeding: appends fake_image_latents to history_latents."
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("history_latents", required=True, type_hint=torch.Tensor),
|
||||
InputParam("fake_image_latents", required=True, type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"history_latents", type_hint=torch.Tensor, description="History latents seeded with fake_image_latents"
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
block_state.history_latents = torch.cat([block_state.history_latents, block_state.fake_image_latents], dim=2)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class HeliosV2VSeedHistoryStep(ModularPipelineBlocks):
|
||||
"""Seeds history_latents with video_latents for V2V pipelines.
|
||||
|
||||
This step runs after HeliosPrepareHistoryStep and replaces the tail of history_latents with video_latents. If the
|
||||
video has fewer frames than the history, the beginning of history is preserved.
|
||||
"""
|
||||
|
||||
model_name = "helios"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "V2V history seeding: replaces the tail of history_latents with video_latents."
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("history_latents", required=True, type_hint=torch.Tensor),
|
||||
InputParam("video_latents", required=True, type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"history_latents", type_hint=torch.Tensor, description="History latents seeded with video_latents"
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
history_latents = block_state.history_latents
|
||||
video_latents = block_state.video_latents
|
||||
|
||||
history_frames = history_latents.shape[2]
|
||||
video_frames = video_latents.shape[2]
|
||||
if video_frames < history_frames:
|
||||
keep_frames = history_frames - video_frames
|
||||
history_latents = torch.cat([history_latents[:, :, :keep_frames, :, :], video_latents], dim=2)
|
||||
else:
|
||||
history_latents = video_latents
|
||||
|
||||
block_state.history_latents = history_latents
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class HeliosSetTimestepsStep(ModularPipelineBlocks):
|
||||
"""Computes scheduler parameters (mu, sigmas) for the chunk loop."""
|
||||
|
||||
model_name = "helios"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Computes scheduler shift parameter (mu) and default sigmas for the Helios chunk loop."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("transformer", HeliosTransformer3DModel),
|
||||
ComponentSpec("scheduler", HeliosScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("latent_shape", required=True, type_hint=tuple),
|
||||
InputParam.template("num_inference_steps"),
|
||||
InputParam.template("sigmas"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("mu", type_hint=float, description="Scheduler shift parameter"),
|
||||
OutputParam("sigmas", type_hint=list, description="Sigma schedule for diffusion"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
patch_size = components.transformer.config.patch_size
|
||||
latent_shape = block_state.latent_shape
|
||||
image_seq_len = (latent_shape[-1] * latent_shape[-2] * latent_shape[-3]) // (
|
||||
patch_size[0] * patch_size[1] * patch_size[2]
|
||||
)
|
||||
|
||||
if block_state.sigmas is None:
|
||||
block_state.sigmas = np.linspace(0.999, 0.0, block_state.num_inference_steps + 1)[:-1]
|
||||
|
||||
block_state.mu = calculate_shift(
|
||||
image_seq_len,
|
||||
components.scheduler.config.get("base_image_seq_len", 256),
|
||||
components.scheduler.config.get("max_image_seq_len", 4096),
|
||||
components.scheduler.config.get("base_shift", 0.5),
|
||||
components.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
110
src/diffusers/modular_pipelines/helios/decoders.py
Normal file
110
src/diffusers/modular_pipelines/helios/decoders.py
Normal file
@@ -0,0 +1,110 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKLWan
|
||||
from ...utils import logging
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HeliosDecodeStep(ModularPipelineBlocks):
|
||||
"""Decode all chunk latents with VAE, trim frames, and postprocess into final video output."""
|
||||
|
||||
model_name = "helios"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Decodes all chunk latents with the VAE, concatenates them, "
|
||||
"trims to the target frame count, and postprocesses into the final video output."
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLWan),
|
||||
ComponentSpec(
|
||||
"video_processor",
|
||||
VideoProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 8}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam(
|
||||
"latent_chunks", required=True, type_hint=list, description="List of per-chunk denoised latent tensors"
|
||||
),
|
||||
InputParam("num_frames", required=True, type_hint=int, description="The target number of output frames"),
|
||||
InputParam.template("output_type", default="np"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"videos",
|
||||
type_hint=list[list[PIL.Image.Image]] | list[torch.Tensor] | list[np.ndarray],
|
||||
description="The generated videos, can be a PIL.Image.Image, torch.Tensor or a numpy array",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
vae = components.vae
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1).to(vae.device, vae.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
|
||||
vae.device, vae.dtype
|
||||
)
|
||||
|
||||
history_video = None
|
||||
for chunk_latents in block_state.latent_chunks:
|
||||
current_latents = chunk_latents.to(vae.dtype) / latents_std + latents_mean
|
||||
current_video = vae.decode(current_latents, return_dict=False)[0]
|
||||
|
||||
if history_video is None:
|
||||
history_video = current_video
|
||||
else:
|
||||
history_video = torch.cat([history_video, current_video], dim=2)
|
||||
|
||||
# Trim to proper frame count
|
||||
generated_frames = history_video.size(2)
|
||||
generated_frames = (
|
||||
generated_frames - 1
|
||||
) // components.vae_scale_factor_temporal * components.vae_scale_factor_temporal + 1
|
||||
history_video = history_video[:, :, :generated_frames]
|
||||
|
||||
block_state.videos = components.video_processor.postprocess_video(
|
||||
history_video, output_type=block_state.output_type
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
|
||||
return components, state
|
||||
1069
src/diffusers/modular_pipelines/helios/denoise.py
Normal file
1069
src/diffusers/modular_pipelines/helios/denoise.py
Normal file
File diff suppressed because it is too large
Load Diff
392
src/diffusers/modular_pipelines/helios/encoders.py
Normal file
392
src/diffusers/modular_pipelines/helios/encoders.py
Normal file
@@ -0,0 +1,392 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from transformers import AutoTokenizer, UMT5EncoderModel
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...models import AutoencoderKLWan
|
||||
from ...utils import is_ftfy_available, logging
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import HeliosModularPipeline
|
||||
|
||||
|
||||
if is_ftfy_available():
|
||||
import ftfy
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def prompt_clean(text):
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
return text
|
||||
|
||||
|
||||
def get_t5_prompt_embeds(
|
||||
text_encoder: UMT5EncoderModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
prompt: str | list[str],
|
||||
max_sequence_length: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype | None = None,
|
||||
):
|
||||
"""Encode text prompts into T5 embeddings for Helios.
|
||||
|
||||
Args:
|
||||
text_encoder: The T5 text encoder model.
|
||||
tokenizer: The tokenizer for the text encoder.
|
||||
prompt: The prompt or prompts to encode.
|
||||
max_sequence_length: Maximum sequence length for tokenization.
|
||||
device: Device to place tensors on.
|
||||
dtype: Optional dtype override. Defaults to `text_encoder.dtype`.
|
||||
|
||||
Returns:
|
||||
A tuple of `(prompt_embeds, attention_mask)` where `prompt_embeds` is the encoded text embeddings and
|
||||
`attention_mask` is a boolean mask.
|
||||
"""
|
||||
dtype = dtype or text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt = [prompt_clean(u) for u in prompt]
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
||||
seq_lens = mask.gt(0).sum(dim=1).long()
|
||||
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
||||
prompt_embeds = torch.stack(
|
||||
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
||||
)
|
||||
|
||||
return prompt_embeds, text_inputs.attention_mask.bool()
|
||||
|
||||
|
||||
class HeliosTextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "helios"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Text Encoder step that generates text embeddings to guide the video generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("text_encoder", UMT5EncoderModel),
|
||||
ComponentSpec("tokenizer", AutoTokenizer),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 5.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam.template("prompt"),
|
||||
InputParam.template("negative_prompt"),
|
||||
InputParam.template("max_sequence_length"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam.template("prompt_embeds"),
|
||||
OutputParam.template("negative_prompt_embeds"),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(prompt, negative_prompt):
|
||||
if prompt is not None and not isinstance(prompt, (str, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if negative_prompt is not None and not isinstance(negative_prompt, (str, list)):
|
||||
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
||||
|
||||
if prompt is not None and negative_prompt is not None:
|
||||
prompt_list = [prompt] if isinstance(prompt, str) else prompt
|
||||
neg_list = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
if type(prompt_list) is not type(neg_list):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
if len(prompt_list) != len(neg_list):
|
||||
raise ValueError(
|
||||
f"`negative_prompt` has batch size {len(neg_list)}, but `prompt` has batch size"
|
||||
f" {len(prompt_list)}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
prompt = block_state.prompt
|
||||
negative_prompt = block_state.negative_prompt
|
||||
max_sequence_length = block_state.max_sequence_length
|
||||
device = components._execution_device
|
||||
|
||||
self.check_inputs(prompt, negative_prompt)
|
||||
|
||||
# Encode prompt
|
||||
block_state.prompt_embeds, _ = get_t5_prompt_embeds(
|
||||
text_encoder=components.text_encoder,
|
||||
tokenizer=components.tokenizer,
|
||||
prompt=prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Encode negative prompt
|
||||
block_state.negative_prompt_embeds = None
|
||||
if components.requires_unconditional_embeds:
|
||||
negative_prompt = negative_prompt or ""
|
||||
if isinstance(prompt, list) and isinstance(negative_prompt, str):
|
||||
negative_prompt = len(prompt) * [negative_prompt]
|
||||
|
||||
block_state.negative_prompt_embeds, _ = get_t5_prompt_embeds(
|
||||
text_encoder=components.text_encoder,
|
||||
tokenizer=components.tokenizer,
|
||||
prompt=negative_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class HeliosImageVaeEncoderStep(ModularPipelineBlocks):
|
||||
"""Encodes an input image into VAE latent space for image-to-video generation."""
|
||||
|
||||
model_name = "helios"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Image Encoder step that encodes an input image into VAE latent space, "
|
||||
"producing image_latents (first frame prefix) and fake_image_latents (history seed) "
|
||||
"for image-to-video generation."
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLWan),
|
||||
ComponentSpec(
|
||||
"video_processor",
|
||||
VideoProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 8}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam.template("image"),
|
||||
InputParam.template("height", default=384),
|
||||
InputParam.template("width", default=640),
|
||||
InputParam(
|
||||
"num_latent_frames_per_chunk",
|
||||
default=9,
|
||||
type_hint=int,
|
||||
description="Number of latent frames per temporal chunk.",
|
||||
),
|
||||
InputParam.template("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam.template("image_latents"),
|
||||
OutputParam(
|
||||
"fake_image_latents", type_hint=torch.Tensor, description="Fake image latents for history seeding"
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
vae = components.vae
|
||||
device = components._execution_device
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1).to(vae.device, vae.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
|
||||
vae.device, vae.dtype
|
||||
)
|
||||
|
||||
# Preprocess image to 4D tensor (B, C, H, W)
|
||||
image = components.video_processor.preprocess(
|
||||
block_state.image, height=block_state.height, width=block_state.width
|
||||
)
|
||||
image_5d = image.unsqueeze(2).to(device=device, dtype=vae.dtype) # (B, C, 1, H, W)
|
||||
|
||||
# Encode image to get image_latents
|
||||
image_latents = vae.encode(image_5d).latent_dist.sample(generator=block_state.generator)
|
||||
image_latents = (image_latents - latents_mean) * latents_std
|
||||
|
||||
# Encode fake video to get fake_image_latents
|
||||
min_frames = (block_state.num_latent_frames_per_chunk - 1) * components.vae_scale_factor_temporal + 1
|
||||
fake_video = image_5d.repeat(1, 1, min_frames, 1, 1) # (B, C, min_frames, H, W)
|
||||
fake_latents_full = vae.encode(fake_video).latent_dist.sample(generator=block_state.generator)
|
||||
fake_latents_full = (fake_latents_full - latents_mean) * latents_std
|
||||
fake_image_latents = fake_latents_full[:, :, -1:, :, :]
|
||||
|
||||
block_state.image_latents = image_latents.to(device=device, dtype=torch.float32)
|
||||
block_state.fake_image_latents = fake_image_latents.to(device=device, dtype=torch.float32)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class HeliosVideoVaeEncoderStep(ModularPipelineBlocks):
|
||||
"""Encodes an input video into VAE latent space for video-to-video generation.
|
||||
|
||||
Produces `image_latents` (first frame) and `video_latents` (remaining frames encoded in chunks).
|
||||
"""
|
||||
|
||||
model_name = "helios"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Video Encoder step that encodes an input video into VAE latent space, "
|
||||
"producing image_latents (first frame) and video_latents (chunked video frames) "
|
||||
"for video-to-video generation."
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLWan),
|
||||
ComponentSpec(
|
||||
"video_processor",
|
||||
VideoProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 8}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("video", required=True, description="Input video for video-to-video generation"),
|
||||
InputParam.template("height", default=384),
|
||||
InputParam.template("width", default=640),
|
||||
InputParam(
|
||||
"num_latent_frames_per_chunk",
|
||||
default=9,
|
||||
type_hint=int,
|
||||
description="Number of latent frames per temporal chunk.",
|
||||
),
|
||||
InputParam.template("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam.template("image_latents"),
|
||||
OutputParam("video_latents", type_hint=torch.Tensor, description="Encoded video latents (chunked)"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: HeliosModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
vae = components.vae
|
||||
device = components._execution_device
|
||||
num_latent_frames_per_chunk = block_state.num_latent_frames_per_chunk
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1).to(vae.device, vae.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(
|
||||
vae.device, vae.dtype
|
||||
)
|
||||
|
||||
# Preprocess video
|
||||
video = components.video_processor.preprocess_video(
|
||||
block_state.video, height=block_state.height, width=block_state.width
|
||||
)
|
||||
video = video.to(device=device, dtype=vae.dtype)
|
||||
|
||||
# Encode video into latents
|
||||
num_frames = video.shape[2]
|
||||
min_frames = (num_latent_frames_per_chunk - 1) * 4 + 1
|
||||
num_chunks = num_frames // min_frames
|
||||
if num_chunks == 0:
|
||||
raise ValueError(
|
||||
f"Video must have at least {min_frames} frames "
|
||||
f"(got {num_frames} frames). "
|
||||
f"Required: (num_latent_frames_per_chunk - 1) * 4 + 1 = ({num_latent_frames_per_chunk} - 1) * 4 + 1 = {min_frames}"
|
||||
)
|
||||
total_valid_frames = num_chunks * min_frames
|
||||
start_frame = num_frames - total_valid_frames
|
||||
|
||||
# Encode first frame
|
||||
first_frame = video[:, :, 0:1, :, :]
|
||||
image_latents = vae.encode(first_frame).latent_dist.sample(generator=block_state.generator)
|
||||
image_latents = (image_latents - latents_mean) * latents_std
|
||||
|
||||
# Encode remaining frames in chunks
|
||||
latents_chunks = []
|
||||
for i in range(num_chunks):
|
||||
chunk_start = start_frame + i * min_frames
|
||||
chunk_end = chunk_start + min_frames
|
||||
video_chunk = video[:, :, chunk_start:chunk_end, :, :]
|
||||
chunk_latents = vae.encode(video_chunk).latent_dist.sample(generator=block_state.generator)
|
||||
chunk_latents = (chunk_latents - latents_mean) * latents_std
|
||||
latents_chunks.append(chunk_latents)
|
||||
video_latents = torch.cat(latents_chunks, dim=2)
|
||||
|
||||
block_state.image_latents = image_latents.to(device=device, dtype=torch.float32)
|
||||
block_state.video_latents = video_latents.to(device=device, dtype=torch.float32)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
542
src/diffusers/modular_pipelines/helios/modular_blocks_helios.py
Normal file
542
src/diffusers/modular_pipelines/helios/modular_blocks_helios.py
Normal file
@@ -0,0 +1,542 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam
|
||||
from .before_denoise import (
|
||||
HeliosAdditionalInputsStep,
|
||||
HeliosAddNoiseToImageLatentsStep,
|
||||
HeliosAddNoiseToVideoLatentsStep,
|
||||
HeliosI2VSeedHistoryStep,
|
||||
HeliosPrepareHistoryStep,
|
||||
HeliosSetTimestepsStep,
|
||||
HeliosTextInputStep,
|
||||
HeliosV2VSeedHistoryStep,
|
||||
)
|
||||
from .decoders import HeliosDecodeStep
|
||||
from .denoise import HeliosChunkDenoiseStep, HeliosI2VChunkDenoiseStep
|
||||
from .encoders import HeliosImageVaeEncoderStep, HeliosTextEncoderStep, HeliosVideoVaeEncoderStep
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. Vae Encoder
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class HeliosAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Encoder step that encodes video or image inputs. This is an auto pipeline block.
|
||||
- `HeliosVideoVaeEncoderStep` (video_encoder) is used when `video` is provided.
|
||||
- `HeliosImageVaeEncoderStep` (image_encoder) is used when `image` is provided.
|
||||
- If neither is provided, step will be skipped.
|
||||
|
||||
Components:
|
||||
vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`)
|
||||
|
||||
Inputs:
|
||||
video (`None`, *optional*):
|
||||
Input video for video-to-video generation
|
||||
height (`int`, *optional*, defaults to 384):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 640):
|
||||
The width in pixels of the generated image.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
image (`Image | list`, *optional*):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
|
||||
Outputs:
|
||||
image_latents (`Tensor`):
|
||||
The latent representation of the input image.
|
||||
video_latents (`Tensor`):
|
||||
Encoded video latents (chunked)
|
||||
fake_image_latents (`Tensor`):
|
||||
Fake image latents for history seeding
|
||||
"""
|
||||
|
||||
block_classes = [HeliosVideoVaeEncoderStep, HeliosImageVaeEncoderStep]
|
||||
block_names = ["video_encoder", "image_encoder"]
|
||||
block_trigger_inputs = ["video", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Encoder step that encodes video or image inputs. This is an auto pipeline block.\n"
|
||||
" - `HeliosVideoVaeEncoderStep` (video_encoder) is used when `video` is provided.\n"
|
||||
" - `HeliosImageVaeEncoderStep` (image_encoder) is used when `image` is provided.\n"
|
||||
" - If neither is provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. DENOISE
|
||||
# ====================
|
||||
|
||||
|
||||
# DENOISE (T2V)
|
||||
# auto_docstring
|
||||
class HeliosCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Denoise block that takes encoded conditions and runs the chunk-based denoising process.
|
||||
|
||||
Components:
|
||||
transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*, defaults to 384):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 640):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, *optional*, defaults to 132):
|
||||
Total number of video frames to generate.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
history_sizes (`list`, *optional*, defaults to [16, 2, 1]):
|
||||
Sizes of long/mid/short history buffers for temporal context.
|
||||
keep_first_frame (`bool`, *optional*, defaults to True):
|
||||
Whether to keep the first frame as a prefix in history.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
timesteps (`Tensor`, *optional*):
|
||||
Timesteps for the denoising process.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
|
||||
Outputs:
|
||||
latent_chunks (`list`):
|
||||
List of per-chunk denoised latent tensors
|
||||
"""
|
||||
|
||||
model_name = "helios"
|
||||
block_classes = [
|
||||
HeliosTextInputStep,
|
||||
HeliosPrepareHistoryStep,
|
||||
HeliosSetTimestepsStep,
|
||||
HeliosChunkDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "prepare_history", "set_timesteps", "chunk_denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Denoise block that takes encoded conditions and runs the chunk-based denoising process."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")]
|
||||
|
||||
|
||||
# DENOISE (I2V)
|
||||
# auto_docstring
|
||||
class HeliosI2VCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
I2V denoise block that seeds history with image latents and uses I2V-aware chunk preparation.
|
||||
|
||||
Components:
|
||||
transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
fake_image_latents (`Tensor`, *optional*):
|
||||
Fake image latents used as history seed for I2V generation.
|
||||
image_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for image latent noise.
|
||||
image_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for image latent noise.
|
||||
video_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for video/fake-image latent noise.
|
||||
video_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for video/fake-image latent noise.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_frames (`int`, *optional*, defaults to 132):
|
||||
Total number of video frames to generate.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
history_sizes (`list`, *optional*, defaults to [16, 2, 1]):
|
||||
Sizes of long/mid/short history buffers for temporal context.
|
||||
keep_first_frame (`bool`, *optional*, defaults to True):
|
||||
Whether to keep the first frame as a prefix in history.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
timesteps (`Tensor`, *optional*):
|
||||
Timesteps for the denoising process.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
|
||||
Outputs:
|
||||
latent_chunks (`list`):
|
||||
List of per-chunk denoised latent tensors
|
||||
"""
|
||||
|
||||
model_name = "helios"
|
||||
block_classes = [
|
||||
HeliosTextInputStep,
|
||||
HeliosAdditionalInputsStep(
|
||||
image_latent_inputs=[InputParam.template("image_latents")],
|
||||
additional_batch_inputs=[
|
||||
InputParam(
|
||||
"fake_image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="Fake image latents used as history seed for I2V generation.",
|
||||
),
|
||||
],
|
||||
),
|
||||
HeliosAddNoiseToImageLatentsStep,
|
||||
HeliosPrepareHistoryStep,
|
||||
HeliosI2VSeedHistoryStep,
|
||||
HeliosSetTimestepsStep,
|
||||
HeliosI2VChunkDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"add_noise_image",
|
||||
"prepare_history",
|
||||
"seed_history",
|
||||
"set_timesteps",
|
||||
"chunk_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "I2V denoise block that seeds history with image latents and uses I2V-aware chunk preparation."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")]
|
||||
|
||||
|
||||
# DENOISE (V2V)
|
||||
# auto_docstring
|
||||
class HeliosV2VCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
V2V denoise block that seeds history with video latents and uses I2V-aware chunk preparation.
|
||||
|
||||
Components:
|
||||
transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
video_latents (`Tensor`, *optional*):
|
||||
Encoded video latents for V2V generation.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
image_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for image latent noise.
|
||||
image_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for image latent noise.
|
||||
video_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for video latent noise.
|
||||
video_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for video latent noise.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_frames (`int`, *optional*, defaults to 132):
|
||||
Total number of video frames to generate.
|
||||
history_sizes (`list`, *optional*, defaults to [16, 2, 1]):
|
||||
Sizes of long/mid/short history buffers for temporal context.
|
||||
keep_first_frame (`bool`, *optional*, defaults to True):
|
||||
Whether to keep the first frame as a prefix in history.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
timesteps (`Tensor`, *optional*):
|
||||
Timesteps for the denoising process.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
|
||||
Outputs:
|
||||
latent_chunks (`list`):
|
||||
List of per-chunk denoised latent tensors
|
||||
"""
|
||||
|
||||
model_name = "helios"
|
||||
block_classes = [
|
||||
HeliosTextInputStep,
|
||||
HeliosAdditionalInputsStep(
|
||||
image_latent_inputs=[InputParam.template("image_latents")],
|
||||
additional_batch_inputs=[
|
||||
InputParam(
|
||||
"video_latents", type_hint=torch.Tensor, description="Encoded video latents for V2V generation."
|
||||
),
|
||||
],
|
||||
),
|
||||
HeliosAddNoiseToVideoLatentsStep,
|
||||
HeliosPrepareHistoryStep,
|
||||
HeliosV2VSeedHistoryStep,
|
||||
HeliosSetTimestepsStep,
|
||||
HeliosI2VChunkDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"add_noise_video",
|
||||
"prepare_history",
|
||||
"seed_history",
|
||||
"set_timesteps",
|
||||
"chunk_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "V2V denoise block that seeds history with video latents and uses I2V-aware chunk preparation."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")]
|
||||
|
||||
|
||||
# AUTO DENOISE
|
||||
# auto_docstring
|
||||
class HeliosAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
"""
|
||||
Core denoise step that selects the appropriate denoising block.
|
||||
- `HeliosV2VCoreDenoiseStep` (video2video) for video-to-video tasks.
|
||||
- `HeliosI2VCoreDenoiseStep` (image2video) for image-to-video tasks.
|
||||
- `HeliosCoreDenoiseStep` (text2video) for text-to-video tasks.
|
||||
|
||||
Components:
|
||||
transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
video_latents (`Tensor`, *optional*):
|
||||
Encoded video latents for V2V generation.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
image_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for image latent noise.
|
||||
image_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for image latent noise.
|
||||
video_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for video latent noise.
|
||||
video_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for video latent noise.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_frames (`int`, *optional*, defaults to 132):
|
||||
Total number of video frames to generate.
|
||||
history_sizes (`list`):
|
||||
Sizes of long/mid/short history buffers for temporal context.
|
||||
keep_first_frame (`bool`, *optional*, defaults to True):
|
||||
Whether to keep the first frame as a prefix in history.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`list`):
|
||||
Custom sigmas for the denoising process.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
timesteps (`Tensor`, *optional*):
|
||||
Timesteps for the denoising process.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
fake_image_latents (`Tensor`, *optional*):
|
||||
Fake image latents used as history seed for I2V generation.
|
||||
height (`int`, *optional*, defaults to 384):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 640):
|
||||
The width in pixels of the generated image.
|
||||
|
||||
Outputs:
|
||||
latent_chunks (`list`):
|
||||
List of per-chunk denoised latent tensors
|
||||
"""
|
||||
|
||||
block_classes = [HeliosV2VCoreDenoiseStep, HeliosI2VCoreDenoiseStep, HeliosCoreDenoiseStep]
|
||||
block_names = ["video2video", "image2video", "text2video"]
|
||||
block_trigger_inputs = ["video_latents", "fake_image_latents"]
|
||||
default_block_name = "text2video"
|
||||
|
||||
def select_block(self, video_latents=None, fake_image_latents=None):
|
||||
if video_latents is not None:
|
||||
return "video2video"
|
||||
elif fake_image_latents is not None:
|
||||
return "image2video"
|
||||
return None
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Core denoise step that selects the appropriate denoising block.\n"
|
||||
" - `HeliosV2VCoreDenoiseStep` (video2video) for video-to-video tasks.\n"
|
||||
" - `HeliosI2VCoreDenoiseStep` (image2video) for image-to-video tasks.\n"
|
||||
" - `HeliosCoreDenoiseStep` (text2video) for text-to-video tasks."
|
||||
)
|
||||
|
||||
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", HeliosTextEncoderStep()),
|
||||
("vae_encoder", HeliosAutoVaeEncoderStep()),
|
||||
("denoise", HeliosAutoCoreDenoiseStep()),
|
||||
("decode", HeliosDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
# ====================
|
||||
# 3. Auto Blocks
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class HeliosAutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Auto Modular pipeline for text-to-video, image-to-video, and video-to-video tasks using Helios.
|
||||
|
||||
Supported workflows:
|
||||
- `text2video`: requires `prompt`
|
||||
- `image2video`: requires `prompt`, `image`
|
||||
- `video2video`: requires `prompt`, `video`
|
||||
|
||||
Components:
|
||||
text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) vae
|
||||
(`AutoencoderKLWan`) video_processor (`VideoProcessor`) transformer (`HeliosTransformer3DModel`) scheduler
|
||||
(`HeliosScheduler`)
|
||||
|
||||
Inputs:
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation.
|
||||
negative_prompt (`str`, *optional*):
|
||||
The prompt or prompts not to guide the image generation.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
Maximum sequence length for prompt encoding.
|
||||
video (`None`, *optional*):
|
||||
Input video for video-to-video generation
|
||||
height (`int`, *optional*, defaults to 384):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 640):
|
||||
The width in pixels of the generated image.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
image (`Image | list`, *optional*):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos to generate per prompt.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
video_latents (`Tensor`, *optional*):
|
||||
Encoded video latents for V2V generation.
|
||||
image_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for image latent noise.
|
||||
image_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for image latent noise.
|
||||
video_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for video latent noise.
|
||||
video_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for video latent noise.
|
||||
num_frames (`int`, *optional*, defaults to 132):
|
||||
Total number of video frames to generate.
|
||||
history_sizes (`list`):
|
||||
Sizes of long/mid/short history buffers for temporal context.
|
||||
keep_first_frame (`bool`, *optional*, defaults to True):
|
||||
Whether to keep the first frame as a prefix in history.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`list`):
|
||||
Custom sigmas for the denoising process.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
timesteps (`Tensor`, *optional*):
|
||||
Timesteps for the denoising process.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
fake_image_latents (`Tensor`, *optional*):
|
||||
Fake image latents used as history seed for I2V generation.
|
||||
output_type (`str`, *optional*, defaults to np):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
|
||||
Outputs:
|
||||
videos (`list`):
|
||||
The generated videos.
|
||||
"""
|
||||
|
||||
model_name = "helios"
|
||||
|
||||
block_classes = AUTO_BLOCKS.values()
|
||||
block_names = AUTO_BLOCKS.keys()
|
||||
|
||||
_workflow_map = {
|
||||
"text2video": {"prompt": True},
|
||||
"image2video": {"prompt": True, "image": True},
|
||||
"video2video": {"prompt": True, "video": True},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Auto Modular pipeline for text-to-video, image-to-video, and video-to-video tasks using Helios."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam.template("videos")]
|
||||
@@ -0,0 +1,520 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam
|
||||
from .before_denoise import (
|
||||
HeliosAdditionalInputsStep,
|
||||
HeliosAddNoiseToImageLatentsStep,
|
||||
HeliosAddNoiseToVideoLatentsStep,
|
||||
HeliosI2VSeedHistoryStep,
|
||||
HeliosPrepareHistoryStep,
|
||||
HeliosTextInputStep,
|
||||
HeliosV2VSeedHistoryStep,
|
||||
)
|
||||
from .decoders import HeliosDecodeStep
|
||||
from .denoise import HeliosPyramidChunkDenoiseStep, HeliosPyramidI2VChunkDenoiseStep
|
||||
from .encoders import HeliosImageVaeEncoderStep, HeliosTextEncoderStep, HeliosVideoVaeEncoderStep
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. Vae Encoder
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class HeliosPyramidAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Encoder step that encodes video or image inputs. This is an auto pipeline block.
|
||||
- `HeliosVideoVaeEncoderStep` (video_encoder) is used when `video` is provided.
|
||||
- `HeliosImageVaeEncoderStep` (image_encoder) is used when `image` is provided.
|
||||
- If neither is provided, step will be skipped.
|
||||
|
||||
Components:
|
||||
vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`)
|
||||
|
||||
Inputs:
|
||||
video (`None`, *optional*):
|
||||
Input video for video-to-video generation
|
||||
height (`int`, *optional*, defaults to 384):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 640):
|
||||
The width in pixels of the generated image.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
image (`Image | list`, *optional*):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
|
||||
Outputs:
|
||||
image_latents (`Tensor`):
|
||||
The latent representation of the input image.
|
||||
video_latents (`Tensor`):
|
||||
Encoded video latents (chunked)
|
||||
fake_image_latents (`Tensor`):
|
||||
Fake image latents for history seeding
|
||||
"""
|
||||
|
||||
block_classes = [HeliosVideoVaeEncoderStep, HeliosImageVaeEncoderStep]
|
||||
block_names = ["video_encoder", "image_encoder"]
|
||||
block_trigger_inputs = ["video", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Encoder step that encodes video or image inputs. This is an auto pipeline block.\n"
|
||||
" - `HeliosVideoVaeEncoderStep` (video_encoder) is used when `video` is provided.\n"
|
||||
" - `HeliosImageVaeEncoderStep` (image_encoder) is used when `image` is provided.\n"
|
||||
" - If neither is provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. DENOISE
|
||||
# ====================
|
||||
|
||||
|
||||
# DENOISE (T2V)
|
||||
# auto_docstring
|
||||
class HeliosPyramidCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
T2V pyramid denoise block with progressive multi-resolution denoising.
|
||||
|
||||
Components:
|
||||
transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider
|
||||
(`ClassifierFreeZeroStarGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*, defaults to 384):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 640):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, *optional*, defaults to 132):
|
||||
Total number of video frames to generate.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
history_sizes (`list`, *optional*, defaults to [16, 2, 1]):
|
||||
Sizes of long/mid/short history buffers for temporal context.
|
||||
keep_first_frame (`bool`, *optional*, defaults to True):
|
||||
Whether to keep the first frame as a prefix in history.
|
||||
pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]):
|
||||
Number of denoising steps per pyramid stage.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
|
||||
Outputs:
|
||||
latent_chunks (`list`):
|
||||
List of per-chunk denoised latent tensors
|
||||
"""
|
||||
|
||||
model_name = "helios-pyramid"
|
||||
block_classes = [
|
||||
HeliosTextInputStep,
|
||||
HeliosPrepareHistoryStep,
|
||||
HeliosPyramidChunkDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "prepare_history", "pyramid_chunk_denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "T2V pyramid denoise block with progressive multi-resolution denoising."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")]
|
||||
|
||||
|
||||
# DENOISE (I2V)
|
||||
# auto_docstring
|
||||
class HeliosPyramidI2VCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
I2V pyramid denoise block with progressive multi-resolution denoising.
|
||||
|
||||
Components:
|
||||
transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider
|
||||
(`ClassifierFreeZeroStarGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
fake_image_latents (`Tensor`, *optional*):
|
||||
Fake image latents used as history seed for I2V generation.
|
||||
image_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for image latent noise.
|
||||
image_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for image latent noise.
|
||||
video_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for video/fake-image latent noise.
|
||||
video_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for video/fake-image latent noise.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_frames (`int`, *optional*, defaults to 132):
|
||||
Total number of video frames to generate.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
history_sizes (`list`, *optional*, defaults to [16, 2, 1]):
|
||||
Sizes of long/mid/short history buffers for temporal context.
|
||||
keep_first_frame (`bool`, *optional*, defaults to True):
|
||||
Whether to keep the first frame as a prefix in history.
|
||||
pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]):
|
||||
Number of denoising steps per pyramid stage.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
|
||||
Outputs:
|
||||
latent_chunks (`list`):
|
||||
List of per-chunk denoised latent tensors
|
||||
"""
|
||||
|
||||
model_name = "helios-pyramid"
|
||||
block_classes = [
|
||||
HeliosTextInputStep,
|
||||
HeliosAdditionalInputsStep(
|
||||
image_latent_inputs=[InputParam.template("image_latents")],
|
||||
additional_batch_inputs=[
|
||||
InputParam(
|
||||
"fake_image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="Fake image latents used as history seed for I2V generation.",
|
||||
),
|
||||
],
|
||||
),
|
||||
HeliosAddNoiseToImageLatentsStep,
|
||||
HeliosPrepareHistoryStep,
|
||||
HeliosI2VSeedHistoryStep,
|
||||
HeliosPyramidI2VChunkDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"add_noise_image",
|
||||
"prepare_history",
|
||||
"seed_history",
|
||||
"pyramid_chunk_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "I2V pyramid denoise block with progressive multi-resolution denoising."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")]
|
||||
|
||||
|
||||
# DENOISE (V2V)
|
||||
# auto_docstring
|
||||
class HeliosPyramidV2VCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
V2V pyramid denoise block with progressive multi-resolution denoising.
|
||||
|
||||
Components:
|
||||
transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider
|
||||
(`ClassifierFreeZeroStarGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
video_latents (`Tensor`, *optional*):
|
||||
Encoded video latents for V2V generation.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
image_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for image latent noise.
|
||||
image_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for image latent noise.
|
||||
video_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for video latent noise.
|
||||
video_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for video latent noise.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_frames (`int`, *optional*, defaults to 132):
|
||||
Total number of video frames to generate.
|
||||
history_sizes (`list`, *optional*, defaults to [16, 2, 1]):
|
||||
Sizes of long/mid/short history buffers for temporal context.
|
||||
keep_first_frame (`bool`, *optional*, defaults to True):
|
||||
Whether to keep the first frame as a prefix in history.
|
||||
pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]):
|
||||
Number of denoising steps per pyramid stage.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
|
||||
Outputs:
|
||||
latent_chunks (`list`):
|
||||
List of per-chunk denoised latent tensors
|
||||
"""
|
||||
|
||||
model_name = "helios-pyramid"
|
||||
block_classes = [
|
||||
HeliosTextInputStep,
|
||||
HeliosAdditionalInputsStep(
|
||||
image_latent_inputs=[InputParam.template("image_latents")],
|
||||
additional_batch_inputs=[
|
||||
InputParam(
|
||||
"video_latents", type_hint=torch.Tensor, description="Encoded video latents for V2V generation."
|
||||
),
|
||||
],
|
||||
),
|
||||
HeliosAddNoiseToVideoLatentsStep,
|
||||
HeliosPrepareHistoryStep,
|
||||
HeliosV2VSeedHistoryStep,
|
||||
HeliosPyramidI2VChunkDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"add_noise_video",
|
||||
"prepare_history",
|
||||
"seed_history",
|
||||
"pyramid_chunk_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "V2V pyramid denoise block with progressive multi-resolution denoising."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")]
|
||||
|
||||
|
||||
# AUTO DENOISE
|
||||
# auto_docstring
|
||||
class HeliosPyramidAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
"""
|
||||
Pyramid core denoise step that selects the appropriate denoising block.
|
||||
- `HeliosPyramidV2VCoreDenoiseStep` (video2video) for video-to-video tasks.
|
||||
- `HeliosPyramidI2VCoreDenoiseStep` (image2video) for image-to-video tasks.
|
||||
- `HeliosPyramidCoreDenoiseStep` (text2video) for text-to-video tasks.
|
||||
|
||||
Components:
|
||||
transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider
|
||||
(`ClassifierFreeZeroStarGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
video_latents (`Tensor`, *optional*):
|
||||
Encoded video latents for V2V generation.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
image_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for image latent noise.
|
||||
image_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for image latent noise.
|
||||
video_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for video latent noise.
|
||||
video_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for video latent noise.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_frames (`int`, *optional*, defaults to 132):
|
||||
Total number of video frames to generate.
|
||||
history_sizes (`list`):
|
||||
Sizes of long/mid/short history buffers for temporal context.
|
||||
keep_first_frame (`bool`, *optional*, defaults to True):
|
||||
Whether to keep the first frame as a prefix in history.
|
||||
pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]):
|
||||
Number of denoising steps per pyramid stage.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
fake_image_latents (`Tensor`, *optional*):
|
||||
Fake image latents used as history seed for I2V generation.
|
||||
height (`int`, *optional*, defaults to 384):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 640):
|
||||
The width in pixels of the generated image.
|
||||
|
||||
Outputs:
|
||||
latent_chunks (`list`):
|
||||
List of per-chunk denoised latent tensors
|
||||
"""
|
||||
|
||||
block_classes = [HeliosPyramidV2VCoreDenoiseStep, HeliosPyramidI2VCoreDenoiseStep, HeliosPyramidCoreDenoiseStep]
|
||||
block_names = ["video2video", "image2video", "text2video"]
|
||||
block_trigger_inputs = ["video_latents", "fake_image_latents"]
|
||||
default_block_name = "text2video"
|
||||
|
||||
def select_block(self, video_latents=None, fake_image_latents=None):
|
||||
if video_latents is not None:
|
||||
return "video2video"
|
||||
elif fake_image_latents is not None:
|
||||
return "image2video"
|
||||
return None
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Pyramid core denoise step that selects the appropriate denoising block.\n"
|
||||
" - `HeliosPyramidV2VCoreDenoiseStep` (video2video) for video-to-video tasks.\n"
|
||||
" - `HeliosPyramidI2VCoreDenoiseStep` (image2video) for image-to-video tasks.\n"
|
||||
" - `HeliosPyramidCoreDenoiseStep` (text2video) for text-to-video tasks."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. Auto Blocks
|
||||
# ====================
|
||||
|
||||
PYRAMID_AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", HeliosTextEncoderStep()),
|
||||
("vae_encoder", HeliosPyramidAutoVaeEncoderStep()),
|
||||
("denoise", HeliosPyramidAutoCoreDenoiseStep()),
|
||||
("decode", HeliosDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class HeliosPyramidAutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Auto Modular pipeline for pyramid progressive generation (T2V/I2V/V2V) using Helios.
|
||||
|
||||
Supported workflows:
|
||||
- `text2video`: requires `prompt`
|
||||
- `image2video`: requires `prompt`, `image`
|
||||
- `video2video`: requires `prompt`, `video`
|
||||
|
||||
Components:
|
||||
text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) vae
|
||||
(`AutoencoderKLWan`) video_processor (`VideoProcessor`) transformer (`HeliosTransformer3DModel`) scheduler
|
||||
(`HeliosScheduler`)
|
||||
|
||||
Inputs:
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation.
|
||||
negative_prompt (`str`, *optional*):
|
||||
The prompt or prompts not to guide the image generation.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
Maximum sequence length for prompt encoding.
|
||||
video (`None`, *optional*):
|
||||
Input video for video-to-video generation
|
||||
height (`int`, *optional*, defaults to 384):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 640):
|
||||
The width in pixels of the generated image.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
image (`Image | list`, *optional*):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos to generate per prompt.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
video_latents (`Tensor`, *optional*):
|
||||
Encoded video latents for V2V generation.
|
||||
image_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for image latent noise.
|
||||
image_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for image latent noise.
|
||||
video_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for video latent noise.
|
||||
video_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for video latent noise.
|
||||
num_frames (`int`, *optional*, defaults to 132):
|
||||
Total number of video frames to generate.
|
||||
history_sizes (`list`):
|
||||
Sizes of long/mid/short history buffers for temporal context.
|
||||
keep_first_frame (`bool`, *optional*, defaults to True):
|
||||
Whether to keep the first frame as a prefix in history.
|
||||
pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]):
|
||||
Number of denoising steps per pyramid stage.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
fake_image_latents (`Tensor`, *optional*):
|
||||
Fake image latents used as history seed for I2V generation.
|
||||
output_type (`str`, *optional*, defaults to np):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
|
||||
Outputs:
|
||||
videos (`list`):
|
||||
The generated videos.
|
||||
"""
|
||||
|
||||
model_name = "helios-pyramid"
|
||||
|
||||
block_classes = PYRAMID_AUTO_BLOCKS.values()
|
||||
block_names = PYRAMID_AUTO_BLOCKS.keys()
|
||||
|
||||
_workflow_map = {
|
||||
"text2video": {"prompt": True},
|
||||
"image2video": {"prompt": True, "image": True},
|
||||
"video2video": {"prompt": True, "video": True},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Auto Modular pipeline for pyramid progressive generation (T2V/I2V/V2V) using Helios."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam.template("videos")]
|
||||
@@ -0,0 +1,530 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InputParam, InsertableDict, OutputParam
|
||||
from .before_denoise import (
|
||||
HeliosAdditionalInputsStep,
|
||||
HeliosAddNoiseToImageLatentsStep,
|
||||
HeliosAddNoiseToVideoLatentsStep,
|
||||
HeliosI2VSeedHistoryStep,
|
||||
HeliosPrepareHistoryStep,
|
||||
HeliosTextInputStep,
|
||||
HeliosV2VSeedHistoryStep,
|
||||
)
|
||||
from .decoders import HeliosDecodeStep
|
||||
from .denoise import HeliosPyramidDistilledChunkDenoiseStep, HeliosPyramidDistilledI2VChunkDenoiseStep
|
||||
from .encoders import HeliosImageVaeEncoderStep, HeliosTextEncoderStep, HeliosVideoVaeEncoderStep
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. Vae Encoder
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class HeliosPyramidDistilledAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Encoder step for distilled pyramid pipeline.
|
||||
- `HeliosVideoVaeEncoderStep` (video_encoder) is used when `video` is provided.
|
||||
- `HeliosImageVaeEncoderStep` (image_encoder) is used when `image` is provided.
|
||||
- If neither is provided, step will be skipped.
|
||||
|
||||
Components:
|
||||
vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`)
|
||||
|
||||
Inputs:
|
||||
video (`None`, *optional*):
|
||||
Input video for video-to-video generation
|
||||
height (`int`, *optional*, defaults to 384):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 640):
|
||||
The width in pixels of the generated image.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
image (`Image | list`, *optional*):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
|
||||
Outputs:
|
||||
image_latents (`Tensor`):
|
||||
The latent representation of the input image.
|
||||
video_latents (`Tensor`):
|
||||
Encoded video latents (chunked)
|
||||
fake_image_latents (`Tensor`):
|
||||
Fake image latents for history seeding
|
||||
"""
|
||||
|
||||
block_classes = [HeliosVideoVaeEncoderStep, HeliosImageVaeEncoderStep]
|
||||
block_names = ["video_encoder", "image_encoder"]
|
||||
block_trigger_inputs = ["video", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Encoder step for distilled pyramid pipeline.\n"
|
||||
" - `HeliosVideoVaeEncoderStep` (video_encoder) is used when `video` is provided.\n"
|
||||
" - `HeliosImageVaeEncoderStep` (image_encoder) is used when `image` is provided.\n"
|
||||
" - If neither is provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. DENOISE
|
||||
# ====================
|
||||
|
||||
|
||||
# DENOISE (T2V)
|
||||
# auto_docstring
|
||||
class HeliosPyramidDistilledCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
T2V distilled pyramid denoise block with DMD scheduler and no CFG.
|
||||
|
||||
Components:
|
||||
transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*, defaults to 384):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 640):
|
||||
The width in pixels of the generated image.
|
||||
num_frames (`int`, *optional*, defaults to 132):
|
||||
Total number of video frames to generate.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
history_sizes (`list`, *optional*, defaults to [16, 2, 1]):
|
||||
Sizes of long/mid/short history buffers for temporal context.
|
||||
keep_first_frame (`bool`, *optional*, defaults to True):
|
||||
Whether to keep the first frame as a prefix in history.
|
||||
pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]):
|
||||
Number of denoising steps per pyramid stage.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
is_amplify_first_chunk (`bool`, *optional*, defaults to True):
|
||||
Whether to double the first chunk's timesteps via the scheduler for amplified generation.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
|
||||
Outputs:
|
||||
latent_chunks (`list`):
|
||||
List of per-chunk denoised latent tensors
|
||||
"""
|
||||
|
||||
model_name = "helios-pyramid"
|
||||
block_classes = [
|
||||
HeliosTextInputStep,
|
||||
HeliosPrepareHistoryStep,
|
||||
HeliosPyramidDistilledChunkDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "prepare_history", "pyramid_chunk_denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "T2V distilled pyramid denoise block with DMD scheduler and no CFG."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")]
|
||||
|
||||
|
||||
# DENOISE (I2V)
|
||||
# auto_docstring
|
||||
class HeliosPyramidDistilledI2VCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
I2V distilled pyramid denoise block with DMD scheduler and no CFG.
|
||||
|
||||
Components:
|
||||
transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
fake_image_latents (`Tensor`, *optional*):
|
||||
Fake image latents used as history seed for I2V generation.
|
||||
image_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for image latent noise.
|
||||
image_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for image latent noise.
|
||||
video_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for video/fake-image latent noise.
|
||||
video_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for video/fake-image latent noise.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_frames (`int`, *optional*, defaults to 132):
|
||||
Total number of video frames to generate.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
history_sizes (`list`, *optional*, defaults to [16, 2, 1]):
|
||||
Sizes of long/mid/short history buffers for temporal context.
|
||||
keep_first_frame (`bool`, *optional*, defaults to True):
|
||||
Whether to keep the first frame as a prefix in history.
|
||||
pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]):
|
||||
Number of denoising steps per pyramid stage.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
is_amplify_first_chunk (`bool`, *optional*, defaults to True):
|
||||
Whether to double the first chunk's timesteps via the scheduler for amplified generation.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
|
||||
Outputs:
|
||||
latent_chunks (`list`):
|
||||
List of per-chunk denoised latent tensors
|
||||
"""
|
||||
|
||||
model_name = "helios-pyramid"
|
||||
block_classes = [
|
||||
HeliosTextInputStep,
|
||||
HeliosAdditionalInputsStep(
|
||||
image_latent_inputs=[InputParam.template("image_latents")],
|
||||
additional_batch_inputs=[
|
||||
InputParam(
|
||||
"fake_image_latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="Fake image latents used as history seed for I2V generation.",
|
||||
),
|
||||
],
|
||||
),
|
||||
HeliosAddNoiseToImageLatentsStep,
|
||||
HeliosPrepareHistoryStep,
|
||||
HeliosI2VSeedHistoryStep,
|
||||
HeliosPyramidDistilledI2VChunkDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"add_noise_image",
|
||||
"prepare_history",
|
||||
"seed_history",
|
||||
"pyramid_chunk_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "I2V distilled pyramid denoise block with DMD scheduler and no CFG."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")]
|
||||
|
||||
|
||||
# DENOISE (V2V)
|
||||
# auto_docstring
|
||||
class HeliosPyramidDistilledV2VCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
V2V distilled pyramid denoise block with DMD scheduler and no CFG.
|
||||
|
||||
Components:
|
||||
transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
video_latents (`Tensor`, *optional*):
|
||||
Encoded video latents for V2V generation.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
image_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for image latent noise.
|
||||
image_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for image latent noise.
|
||||
video_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for video latent noise.
|
||||
video_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for video latent noise.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_frames (`int`, *optional*, defaults to 132):
|
||||
Total number of video frames to generate.
|
||||
history_sizes (`list`, *optional*, defaults to [16, 2, 1]):
|
||||
Sizes of long/mid/short history buffers for temporal context.
|
||||
keep_first_frame (`bool`, *optional*, defaults to True):
|
||||
Whether to keep the first frame as a prefix in history.
|
||||
pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]):
|
||||
Number of denoising steps per pyramid stage.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
is_amplify_first_chunk (`bool`, *optional*, defaults to True):
|
||||
Whether to double the first chunk's timesteps via the scheduler for amplified generation.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
|
||||
Outputs:
|
||||
latent_chunks (`list`):
|
||||
List of per-chunk denoised latent tensors
|
||||
"""
|
||||
|
||||
model_name = "helios-pyramid"
|
||||
block_classes = [
|
||||
HeliosTextInputStep,
|
||||
HeliosAdditionalInputsStep(
|
||||
image_latent_inputs=[InputParam.template("image_latents")],
|
||||
additional_batch_inputs=[
|
||||
InputParam(
|
||||
"video_latents", type_hint=torch.Tensor, description="Encoded video latents for V2V generation."
|
||||
),
|
||||
],
|
||||
),
|
||||
HeliosAddNoiseToVideoLatentsStep,
|
||||
HeliosPrepareHistoryStep,
|
||||
HeliosV2VSeedHistoryStep,
|
||||
HeliosPyramidDistilledI2VChunkDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"add_noise_video",
|
||||
"prepare_history",
|
||||
"seed_history",
|
||||
"pyramid_chunk_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "V2V distilled pyramid denoise block with DMD scheduler and no CFG."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors")]
|
||||
|
||||
|
||||
# AUTO DENOISE
|
||||
# auto_docstring
|
||||
class HeliosPyramidDistilledAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
"""
|
||||
Distilled pyramid core denoise step that selects the appropriate denoising block.
|
||||
- `HeliosPyramidDistilledV2VCoreDenoiseStep` (video2video) for video-to-video tasks.
|
||||
- `HeliosPyramidDistilledI2VCoreDenoiseStep` (image2video) for image-to-video tasks.
|
||||
- `HeliosPyramidDistilledCoreDenoiseStep` (text2video) for text-to-video tasks.
|
||||
|
||||
Components:
|
||||
transformer (`HeliosTransformer3DModel`) scheduler (`HeliosScheduler`) guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos to generate per prompt.
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
negative text embeddings used to guide the image generation. Can be generated from text_encoder step.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
video_latents (`Tensor`, *optional*):
|
||||
Encoded video latents for V2V generation.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
image_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for image latent noise.
|
||||
image_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for image latent noise.
|
||||
video_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for video latent noise.
|
||||
video_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for video latent noise.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
num_frames (`int`, *optional*, defaults to 132):
|
||||
Total number of video frames to generate.
|
||||
history_sizes (`list`):
|
||||
Sizes of long/mid/short history buffers for temporal context.
|
||||
keep_first_frame (`bool`, *optional*, defaults to True):
|
||||
Whether to keep the first frame as a prefix in history.
|
||||
pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]):
|
||||
Number of denoising steps per pyramid stage.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
is_amplify_first_chunk (`bool`, *optional*, defaults to True):
|
||||
Whether to double the first chunk's timesteps via the scheduler for amplified generation.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
fake_image_latents (`Tensor`, *optional*):
|
||||
Fake image latents used as history seed for I2V generation.
|
||||
height (`int`, *optional*, defaults to 384):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 640):
|
||||
The width in pixels of the generated image.
|
||||
|
||||
Outputs:
|
||||
latent_chunks (`list`):
|
||||
List of per-chunk denoised latent tensors
|
||||
"""
|
||||
|
||||
block_classes = [
|
||||
HeliosPyramidDistilledV2VCoreDenoiseStep,
|
||||
HeliosPyramidDistilledI2VCoreDenoiseStep,
|
||||
HeliosPyramidDistilledCoreDenoiseStep,
|
||||
]
|
||||
block_names = ["video2video", "image2video", "text2video"]
|
||||
block_trigger_inputs = ["video_latents", "fake_image_latents"]
|
||||
default_block_name = "text2video"
|
||||
|
||||
def select_block(self, video_latents=None, fake_image_latents=None):
|
||||
if video_latents is not None:
|
||||
return "video2video"
|
||||
elif fake_image_latents is not None:
|
||||
return "image2video"
|
||||
return None
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Distilled pyramid core denoise step that selects the appropriate denoising block.\n"
|
||||
" - `HeliosPyramidDistilledV2VCoreDenoiseStep` (video2video) for video-to-video tasks.\n"
|
||||
" - `HeliosPyramidDistilledI2VCoreDenoiseStep` (image2video) for image-to-video tasks.\n"
|
||||
" - `HeliosPyramidDistilledCoreDenoiseStep` (text2video) for text-to-video tasks."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. Auto Blocks
|
||||
# ====================
|
||||
|
||||
DISTILLED_PYRAMID_AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", HeliosTextEncoderStep()),
|
||||
("vae_encoder", HeliosPyramidDistilledAutoVaeEncoderStep()),
|
||||
("denoise", HeliosPyramidDistilledAutoCoreDenoiseStep()),
|
||||
("decode", HeliosDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class HeliosPyramidDistilledAutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Auto Modular pipeline for distilled pyramid progressive generation (T2V/I2V/V2V) using Helios.
|
||||
|
||||
Supported workflows:
|
||||
- `text2video`: requires `prompt`
|
||||
- `image2video`: requires `prompt`, `image`
|
||||
- `video2video`: requires `prompt`, `video`
|
||||
|
||||
Components:
|
||||
text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) vae
|
||||
(`AutoencoderKLWan`) video_processor (`VideoProcessor`) transformer (`HeliosTransformer3DModel`) scheduler
|
||||
(`HeliosScheduler`)
|
||||
|
||||
Inputs:
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation.
|
||||
negative_prompt (`str`, *optional*):
|
||||
The prompt or prompts not to guide the image generation.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
Maximum sequence length for prompt encoding.
|
||||
video (`None`, *optional*):
|
||||
Input video for video-to-video generation
|
||||
height (`int`, *optional*, defaults to 384):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 640):
|
||||
The width in pixels of the generated image.
|
||||
num_latent_frames_per_chunk (`int`, *optional*, defaults to 9):
|
||||
Number of latent frames per temporal chunk.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
image (`Image | list`, *optional*):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of videos to generate per prompt.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
video_latents (`Tensor`, *optional*):
|
||||
Encoded video latents for V2V generation.
|
||||
image_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for image latent noise.
|
||||
image_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for image latent noise.
|
||||
video_noise_sigma_min (`float`, *optional*, defaults to 0.111):
|
||||
Minimum sigma for video latent noise.
|
||||
video_noise_sigma_max (`float`, *optional*, defaults to 0.135):
|
||||
Maximum sigma for video latent noise.
|
||||
num_frames (`int`, *optional*, defaults to 132):
|
||||
Total number of video frames to generate.
|
||||
history_sizes (`list`):
|
||||
Sizes of long/mid/short history buffers for temporal context.
|
||||
keep_first_frame (`bool`, *optional*, defaults to True):
|
||||
Whether to keep the first frame as a prefix in history.
|
||||
pyramid_num_inference_steps_list (`list`, *optional*, defaults to [10, 10, 10]):
|
||||
Number of denoising steps per pyramid stage.
|
||||
latents (`Tensor`, *optional*):
|
||||
Pre-generated noisy latents for image generation.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
is_amplify_first_chunk (`bool`, *optional*, defaults to True):
|
||||
Whether to double the first chunk's timesteps via the scheduler for amplified generation.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
fake_image_latents (`Tensor`, *optional*):
|
||||
Fake image latents used as history seed for I2V generation.
|
||||
output_type (`str`, *optional*, defaults to np):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
|
||||
Outputs:
|
||||
videos (`list`):
|
||||
The generated videos.
|
||||
"""
|
||||
|
||||
model_name = "helios-pyramid"
|
||||
|
||||
block_classes = DISTILLED_PYRAMID_AUTO_BLOCKS.values()
|
||||
block_names = DISTILLED_PYRAMID_AUTO_BLOCKS.keys()
|
||||
|
||||
_workflow_map = {
|
||||
"text2video": {"prompt": True},
|
||||
"image2video": {"prompt": True, "image": True},
|
||||
"video2video": {"prompt": True, "video": True},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Auto Modular pipeline for distilled pyramid progressive generation (T2V/I2V/V2V) using Helios."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam.template("videos")]
|
||||
87
src/diffusers/modular_pipelines/helios/modular_pipeline.py
Normal file
87
src/diffusers/modular_pipelines/helios/modular_pipeline.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ...loaders import HeliosLoraLoaderMixin
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class HeliosModularPipeline(
|
||||
ModularPipeline,
|
||||
HeliosLoraLoaderMixin,
|
||||
):
|
||||
"""
|
||||
A ModularPipeline for Helios text-to-video generation.
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "HeliosAutoBlocks"
|
||||
|
||||
@property
|
||||
def vae_scale_factor_spatial(self):
|
||||
vae_scale_factor = 8
|
||||
if hasattr(self, "vae") and self.vae is not None:
|
||||
vae_scale_factor = self.vae.config.scale_factor_spatial
|
||||
return vae_scale_factor
|
||||
|
||||
@property
|
||||
def vae_scale_factor_temporal(self):
|
||||
vae_scale_factor = 4
|
||||
if hasattr(self, "vae") and self.vae is not None:
|
||||
vae_scale_factor = self.vae.config.scale_factor_temporal
|
||||
return vae_scale_factor
|
||||
|
||||
@property
|
||||
def num_channels_latents(self):
|
||||
# YiYi TODO: find out default value
|
||||
num_channels_latents = 16
|
||||
if hasattr(self, "transformer") and self.transformer is not None:
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
return num_channels_latents
|
||||
|
||||
@property
|
||||
def requires_unconditional_embeds(self):
|
||||
requires_unconditional_embeds = False
|
||||
|
||||
if hasattr(self, "guider") and self.guider is not None:
|
||||
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
|
||||
|
||||
return requires_unconditional_embeds
|
||||
|
||||
|
||||
class HeliosPyramidModularPipeline(HeliosModularPipeline):
|
||||
"""
|
||||
A ModularPipeline for Helios pyramid (progressive resolution) video generation.
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "HeliosPyramidAutoBlocks"
|
||||
|
||||
|
||||
class HeliosPyramidDistilledModularPipeline(HeliosModularPipeline):
|
||||
"""
|
||||
A ModularPipeline for Helios distilled pyramid video generation using DMD scheduler.
|
||||
|
||||
Uses guidance_scale=1.0 (no CFG) and supports is_amplify_first_chunk for the DMD scheduler.
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "HeliosPyramidDistilledAutoBlocks"
|
||||
@@ -106,6 +106,16 @@ def _wan_i2v_map_fn(config_dict=None):
|
||||
return "WanImage2VideoModularPipeline"
|
||||
|
||||
|
||||
def _helios_pyramid_map_fn(config_dict=None):
|
||||
if config_dict is None:
|
||||
return "HeliosPyramidModularPipeline"
|
||||
|
||||
if config_dict.get("is_distilled", False):
|
||||
return "HeliosPyramidDistilledModularPipeline"
|
||||
else:
|
||||
return "HeliosPyramidModularPipeline"
|
||||
|
||||
|
||||
MODULAR_PIPELINE_MAPPING = OrderedDict(
|
||||
[
|
||||
("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")),
|
||||
@@ -120,6 +130,8 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
|
||||
("qwenimage-edit-plus", _create_default_map_fn("QwenImageEditPlusModularPipeline")),
|
||||
("qwenimage-layered", _create_default_map_fn("QwenImageLayeredModularPipeline")),
|
||||
("z-image", _create_default_map_fn("ZImageModularPipeline")),
|
||||
("helios", _create_default_map_fn("HeliosModularPipeline")),
|
||||
("helios-pyramid", _helios_pyramid_map_fn),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -129,7 +129,7 @@ else:
|
||||
]
|
||||
_import_structure["bria"] = ["BriaPipeline"]
|
||||
_import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
|
||||
_import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"]
|
||||
_import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline", "Flux2KleinKVPipeline"]
|
||||
_import_structure["flux"] = [
|
||||
"FluxControlPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
@@ -671,7 +671,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxPriorReduxPipeline,
|
||||
ReduxImageEncoder,
|
||||
)
|
||||
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
|
||||
from .flux2 import Flux2KleinKVPipeline, Flux2KleinPipeline, Flux2Pipeline
|
||||
from .glm_image import GlmImagePipeline
|
||||
from .helios import HeliosPipeline, HeliosPyramidPipeline
|
||||
from .hidream_image import HiDreamImagePipeline
|
||||
|
||||
@@ -95,6 +95,7 @@ from .pag import (
|
||||
StableDiffusionXLPAGPipeline,
|
||||
)
|
||||
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
|
||||
from .prx import PRXPipeline
|
||||
from .qwenimage import (
|
||||
QwenImageControlNetPipeline,
|
||||
QwenImageEditInpaintPipeline,
|
||||
@@ -185,6 +186,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline),
|
||||
("z-image-omni", ZImageOmniPipeline),
|
||||
("ovis", OvisImagePipeline),
|
||||
("prx", PRXPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -82,13 +82,16 @@ EXAMPLE_DOC_STRING = """
|
||||
```python
|
||||
>>> import cv2
|
||||
>>> import numpy as np
|
||||
>>> from PIL import Image
|
||||
>>> import torch
|
||||
>>> from diffusers import Cosmos2_5_TransferPipeline, AutoModel
|
||||
>>> from diffusers.utils import export_to_video, load_video
|
||||
|
||||
>>> model_id = "nvidia/Cosmos-Transfer2.5-2B"
|
||||
>>> # Load a Transfer2.5 controlnet variant (edge, depth, seg, or blur)
|
||||
>>> controlnet = AutoModel.from_pretrained(model_id, revision="diffusers/controlnet/general/edge")
|
||||
>>> controlnet = AutoModel.from_pretrained(
|
||||
... model_id, revision="diffusers/controlnet/general/edge", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe = Cosmos2_5_TransferPipeline.from_pretrained(
|
||||
... model_id, controlnet=controlnet, revision="diffusers/general", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
|
||||
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipeline_flux2"] = ["Flux2Pipeline"]
|
||||
_import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"]
|
||||
_import_structure["pipeline_flux2_klein_kv"] = ["Flux2KleinKVPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipeline_flux2 import Flux2Pipeline
|
||||
from .pipeline_flux2_klein import Flux2KleinPipeline
|
||||
from .pipeline_flux2_klein_kv import Flux2KleinKVPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
886
src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py
Normal file
886
src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py
Normal file
@@ -0,0 +1,886 @@
|
||||
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM
|
||||
|
||||
from ...loaders import Flux2LoraLoaderMixin
|
||||
from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel
|
||||
from ...models.transformers.transformer_flux2 import Flux2KVAttnProcessor, Flux2KVParallelSelfAttnProcessor
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .image_processor import Flux2ImageProcessor
|
||||
from .pipeline_output import Flux2PipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> from diffusers import Flux2KleinKVPipeline
|
||||
|
||||
>>> pipe = Flux2KleinKVPipeline.from_pretrained(
|
||||
... "black-forest-labs/FLUX.2-klein-9b-kv", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> ref_image = Image.open("reference.png")
|
||||
>>> image = pipe("A cat dressed like a wizard", image=ref_image, num_inference_steps=4).images[0]
|
||||
>>> image.save("flux2_kv_output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu
|
||||
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
|
||||
a1, b1 = 8.73809524e-05, 1.89833333
|
||||
a2, b2 = 0.00016927, 0.45666666
|
||||
|
||||
if image_seq_len > 4300:
|
||||
mu = a2 * image_seq_len + b2
|
||||
return float(mu)
|
||||
|
||||
m_200 = a2 * image_seq_len + b2
|
||||
m_10 = a1 * image_seq_len + b1
|
||||
|
||||
a = (m_200 - m_10) / 190.0
|
||||
b = m_200 - 200.0 * a
|
||||
mu = a * num_steps + b
|
||||
|
||||
return float(mu)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: int | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
timesteps: list[int] | None = None,
|
||||
sigmas: list[float] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`list[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`list[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class Flux2KleinKVPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
r"""
|
||||
The Flux2 Klein KV pipeline for text-to-image generation with KV-cached reference image conditioning.
|
||||
|
||||
On the first denoising step, reference image tokens are included in the forward pass and their attention K/V
|
||||
projections are cached. On subsequent steps, the cached K/V are reused without recomputing, providing faster
|
||||
inference when using reference images.
|
||||
|
||||
Reference:
|
||||
[https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence)
|
||||
|
||||
Args:
|
||||
transformer ([`Flux2Transformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLFlux2`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Qwen3ForCausalLM`]):
|
||||
[Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM)
|
||||
tokenizer (`Qwen2TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKLFlux2,
|
||||
text_encoder: Qwen3ForCausalLM,
|
||||
tokenizer: Qwen2TokenizerFast,
|
||||
transformer: Flux2Transformer2DModel,
|
||||
is_distilled: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
||||
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
||||
self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
self.tokenizer_max_length = 512
|
||||
self.default_sample_size = 128
|
||||
|
||||
# Set KV-cache-aware attention processors
|
||||
self._set_kv_attn_processors()
|
||||
|
||||
@staticmethod
|
||||
def _get_qwen3_prompt_embeds(
|
||||
text_encoder: Qwen3ForCausalLM,
|
||||
tokenizer: Qwen2TokenizerFast,
|
||||
prompt: str | list[str],
|
||||
dtype: torch.dtype | None = None,
|
||||
device: torch.device | None = None,
|
||||
max_sequence_length: int = 512,
|
||||
hidden_states_layers: list[int] = (9, 18, 27),
|
||||
):
|
||||
dtype = text_encoder.dtype if dtype is None else dtype
|
||||
device = text_encoder.device if device is None else device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
all_input_ids = []
|
||||
all_attention_masks = []
|
||||
|
||||
for single_prompt in prompt:
|
||||
messages = [{"role": "user", "content": single_prompt}]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_sequence_length,
|
||||
)
|
||||
|
||||
all_input_ids.append(inputs["input_ids"])
|
||||
all_attention_masks.append(inputs["attention_mask"])
|
||||
|
||||
input_ids = torch.cat(all_input_ids, dim=0).to(device)
|
||||
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
|
||||
|
||||
# Forward pass through the model
|
||||
output = text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
# Only use outputs from intermediate layers and stack them
|
||||
out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1)
|
||||
out = out.to(dtype=dtype, device=device)
|
||||
|
||||
batch_size, num_channels, seq_len, hidden_dim = out.shape
|
||||
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids
|
||||
def _prepare_text_ids(
|
||||
x: torch.Tensor, # (B, L, D) or (L, D)
|
||||
t_coord: torch.Tensor | None = None,
|
||||
):
|
||||
B, L, _ = x.shape
|
||||
out_ids = []
|
||||
|
||||
for i in range(B):
|
||||
t = torch.arange(1) if t_coord is None else t_coord[i]
|
||||
h = torch.arange(1)
|
||||
w = torch.arange(1)
|
||||
l = torch.arange(L)
|
||||
|
||||
coords = torch.cartesian_prod(t, h, w, l)
|
||||
out_ids.append(coords)
|
||||
|
||||
return torch.stack(out_ids)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids
|
||||
def _prepare_latent_ids(
|
||||
latents: torch.Tensor, # (B, C, H, W)
|
||||
):
|
||||
r"""
|
||||
Generates 4D position coordinates (T, H, W, L) for latent tensors.
|
||||
|
||||
Args:
|
||||
latents (torch.Tensor):
|
||||
Latent tensor of shape (B, C, H, W)
|
||||
|
||||
Returns:
|
||||
torch.Tensor:
|
||||
Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
|
||||
H=[0..H-1], W=[0..W-1], L=0
|
||||
"""
|
||||
|
||||
batch_size, _, height, width = latents.shape
|
||||
|
||||
t = torch.arange(1) # [0] - time dimension
|
||||
h = torch.arange(height)
|
||||
w = torch.arange(width)
|
||||
l = torch.arange(1) # [0] - layer dimension
|
||||
|
||||
# Create position IDs: (H*W, 4)
|
||||
latent_ids = torch.cartesian_prod(t, h, w, l)
|
||||
|
||||
# Expand to batch: (B, H*W, 4)
|
||||
latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
|
||||
return latent_ids
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids
|
||||
def _prepare_image_ids(
|
||||
image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
|
||||
scale: int = 10,
|
||||
):
|
||||
r"""
|
||||
Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
|
||||
|
||||
This function creates a unique coordinate for every pixel/patch across all input latent with different
|
||||
dimensions.
|
||||
|
||||
Args:
|
||||
image_latents (list[torch.Tensor]):
|
||||
A list of image latent feature tensors, typically of shape (C, H, W).
|
||||
scale (int, optional):
|
||||
A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
|
||||
latent is: 'scale + scale * i'. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
torch.Tensor:
|
||||
The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
|
||||
input latents.
|
||||
|
||||
Coordinate Components (Dimension 4):
|
||||
- T (Time): The unique index indicating which latent image the coordinate belongs to.
|
||||
- H (Height): The row index within that latent image.
|
||||
- W (Width): The column index within that latent image.
|
||||
- L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
|
||||
"""
|
||||
|
||||
if not isinstance(image_latents, list):
|
||||
raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
|
||||
|
||||
# create time offset for each reference image
|
||||
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
|
||||
t_coords = [t.view(-1) for t in t_coords]
|
||||
|
||||
image_latent_ids = []
|
||||
for x, t in zip(image_latents, t_coords):
|
||||
x = x.squeeze(0)
|
||||
_, height, width = x.shape
|
||||
|
||||
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
|
||||
image_latent_ids.append(x_ids)
|
||||
|
||||
image_latent_ids = torch.cat(image_latent_ids, dim=0)
|
||||
image_latent_ids = image_latent_ids.unsqueeze(0)
|
||||
|
||||
return image_latent_ids
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents
|
||||
def _patchify_latents(latents):
|
||||
batch_size, num_channels_latents, height, width = latents.shape
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 1, 3, 5, 2, 4)
|
||||
latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents
|
||||
def _unpatchify_latents(latents):
|
||||
batch_size, num_channels_latents, height, width = latents.shape
|
||||
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
|
||||
latents = latents.permute(0, 1, 4, 2, 5, 3)
|
||||
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents
|
||||
def _pack_latents(latents):
|
||||
"""
|
||||
pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
|
||||
"""
|
||||
|
||||
batch_size, num_channels, height, width = latents.shape
|
||||
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
|
||||
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
|
||||
"""
|
||||
using position ids to scatter tokens into place
|
||||
"""
|
||||
x_list = []
|
||||
for data, pos in zip(x, x_ids):
|
||||
_, ch = data.shape # noqa: F841
|
||||
h_ids = pos[:, 1].to(torch.int64)
|
||||
w_ids = pos[:, 2].to(torch.int64)
|
||||
|
||||
h = torch.max(h_ids) + 1
|
||||
w = torch.max(w_ids) + 1
|
||||
|
||||
flat_ids = h_ids * w + w_ids
|
||||
|
||||
out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
|
||||
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
|
||||
|
||||
# reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
|
||||
|
||||
out = out.view(h, w, ch).permute(2, 0, 1)
|
||||
x_list.append(out)
|
||||
|
||||
return torch.stack(x_list, dim=0)
|
||||
|
||||
def _set_kv_attn_processors(self):
|
||||
"""Replace default attention processors with KV-cache-aware variants."""
|
||||
for block in self.transformer.transformer_blocks:
|
||||
block.attn.set_processor(Flux2KVAttnProcessor())
|
||||
for block in self.transformer.single_transformer_blocks:
|
||||
block.attn.set_processor(Flux2KVParallelSelfAttnProcessor())
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str | list[str],
|
||||
device: torch.device | None = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
max_sequence_length: int = 512,
|
||||
text_encoder_out_layers: tuple[int] = (9, 18, 27),
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_qwen3_prompt_embeds(
|
||||
text_encoder=self.text_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
hidden_states_layers=text_encoder_out_layers,
|
||||
)
|
||||
|
||||
batch_size, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
text_ids = self._prepare_text_ids(prompt_embeds)
|
||||
text_ids = text_ids.to(device)
|
||||
return prompt_embeds, text_ids
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image
|
||||
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
|
||||
if image.ndim != 4:
|
||||
raise ValueError(f"Expected image dims 4, got {image.ndim}.")
|
||||
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
|
||||
image_latents = self._patchify_latents(image_latents)
|
||||
|
||||
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
|
||||
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps)
|
||||
image_latents = (image_latents - latents_bn_mean) / latents_bn_std
|
||||
|
||||
return image_latents
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_latents_channels,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator: torch.Generator,
|
||||
latents: torch.Tensor | None = None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
latent_ids = self._prepare_latent_ids(latents)
|
||||
latent_ids = latent_ids.to(device)
|
||||
|
||||
latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
|
||||
return latents, latent_ids
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents
|
||||
def prepare_image_latents(
|
||||
self,
|
||||
images: list[torch.Tensor],
|
||||
batch_size,
|
||||
generator: torch.Generator,
|
||||
device,
|
||||
dtype,
|
||||
):
|
||||
image_latents = []
|
||||
for image in images:
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
imagge_latent = self._encode_vae_image(image=image, generator=generator)
|
||||
image_latents.append(imagge_latent) # (1, 128, 32, 32)
|
||||
|
||||
image_latent_ids = self._prepare_image_ids(image_latents)
|
||||
|
||||
# Pack each latent and concatenate
|
||||
packed_latents = []
|
||||
for latent in image_latents:
|
||||
# latent: (1, 128, 32, 32)
|
||||
packed = self._pack_latents(latent) # (1, 1024, 128)
|
||||
packed = packed.squeeze(0) # (1024, 128) - remove batch dim
|
||||
packed_latents.append(packed)
|
||||
|
||||
# Concatenate all reference tokens along sequence dimension
|
||||
image_latents = torch.cat(packed_latents, dim=0) # (N*1024, 128)
|
||||
image_latents = image_latents.unsqueeze(0) # (1, N*1024, 128)
|
||||
|
||||
image_latents = image_latents.repeat(batch_size, 1, 1)
|
||||
image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1)
|
||||
image_latent_ids = image_latent_ids.to(device)
|
||||
|
||||
return image_latents, image_latent_ids
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if (
|
||||
height is not None
|
||||
and height % (self.vae_scale_factor * 2) != 0
|
||||
or width is not None
|
||||
and width % (self.vae_scale_factor * 2) != 0
|
||||
):
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: list[PIL.Image.Image] | PIL.Image.Image | None = None,
|
||||
prompt: str | list[str] = None,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
num_inference_steps: int = 4,
|
||||
sigmas: list[float] | None = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
callback_on_step_end: Callable[[int, int, dict], None] | None = None,
|
||||
callback_on_step_end_tensor_inputs: list[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
text_encoder_out_layers: tuple[int] = (9, 18, 27),
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*):
|
||||
Reference image(s) for conditioning. On the first denoising step, reference tokens are included in the
|
||||
forward pass and their attention K/V are cached. On subsequent steps, the cached K/V are reused without
|
||||
recomputing.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 4):
|
||||
The number of denoising steps.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas for the denoising schedule.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
Generator(s) for deterministic generation.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
Output format: `"pil"` or `"np"`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a `Flux2PipelineOutput` or a plain tuple.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Extra kwargs passed to attention processors.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
Callback function called at the end of each denoising step.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
Tensor inputs for the callback function.
|
||||
max_sequence_length (`int`, defaults to 512):
|
||||
Maximum sequence length for the prompt.
|
||||
text_encoder_out_layers (`tuple[int]`):
|
||||
Layer indices for text encoder hidden state extraction.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`.
|
||||
"""
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
prompt_embeds=prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. prepare text embeddings
|
||||
prompt_embeds, text_ids = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
text_encoder_out_layers=text_encoder_out_layers,
|
||||
)
|
||||
|
||||
# 4. process images
|
||||
if image is not None and not isinstance(image, list):
|
||||
image = [image]
|
||||
|
||||
condition_images = None
|
||||
if image is not None:
|
||||
for img in image:
|
||||
self.image_processor.check_image_input(img)
|
||||
|
||||
condition_images = []
|
||||
for img in image:
|
||||
image_width, image_height = img.size
|
||||
if image_width * image_height > 1024 * 1024:
|
||||
img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
|
||||
image_width, image_height = img.size
|
||||
|
||||
multiple_of = self.vae_scale_factor * 2
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
|
||||
condition_images.append(img)
|
||||
height = height or image_height
|
||||
width = width or image_width
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 5. prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels // 4
|
||||
latents, latent_ids = self.prepare_latents(
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_latents_channels=num_channels_latents,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
image_latents = None
|
||||
image_latent_ids = None
|
||||
if condition_images is not None:
|
||||
image_latents, image_latent_ids = self.prepare_image_latents(
|
||||
images=condition_images,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
generator=generator,
|
||||
device=device,
|
||||
dtype=self.vae.dtype,
|
||||
)
|
||||
|
||||
# 6. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
|
||||
sigmas = None
|
||||
image_seq_len = latents.shape[1]
|
||||
mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 7. Denoising loop with KV caching
|
||||
# Step 0 with ref images: forward_kv_extract (full pass, cache ref K/V)
|
||||
# Steps 1+: forward_kv_cached (reuse cached ref K/V)
|
||||
# No ref images: standard forward
|
||||
self.scheduler.set_begin_index(0)
|
||||
kv_cache = None
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
||||
|
||||
if i == 0 and image_latents is not None:
|
||||
# Step 0: include ref tokens, extract KV cache
|
||||
latent_model_input = torch.cat([image_latents, latents], dim=1).to(self.transformer.dtype)
|
||||
latent_image_ids = torch.cat([image_latent_ids, latent_ids], dim=1)
|
||||
|
||||
noise_pred, kv_cache = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_image_ids,
|
||||
joint_attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
kv_cache_mode="extract",
|
||||
num_ref_tokens=image_latents.shape[1],
|
||||
)
|
||||
|
||||
elif kv_cache is not None:
|
||||
# Steps 1+: use cached ref KV, no ref tokens in input
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents.to(self.transformer.dtype),
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_ids,
|
||||
joint_attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
kv_cache=kv_cache,
|
||||
kv_cache_mode="cached",
|
||||
)[0]
|
||||
|
||||
else:
|
||||
# No reference images: standard forward
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents.to(self.transformer.dtype),
|
||||
timestep=timestep / 1000,
|
||||
guidance=None,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
txt_ids=text_ids,
|
||||
img_ids=latent_ids,
|
||||
joint_attention_kwargs=self.attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
# Clean up KV cache
|
||||
if kv_cache is not None:
|
||||
kv_cache.clear()
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
latents = self._unpack_latents_with_ids(latents, latent_ids)
|
||||
|
||||
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
latents = latents * latents_bn_std + latents_bn_mean
|
||||
latents = self._unpatchify_latents(latents)
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
else:
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return Flux2PipelineOutput(images=image)
|
||||
@@ -456,6 +456,8 @@ class HeliosPyramidPipeline(DiffusionPipeline, HeliosLoraLoaderMixin):
|
||||
# the output will be non-deterministic and may produce incorrect results in CP context.
|
||||
if generator is None:
|
||||
generator = torch.Generator(device=device)
|
||||
elif isinstance(generator, list):
|
||||
generator = generator[0]
|
||||
|
||||
gamma = self.scheduler.config.gamma
|
||||
_, ph, pw = patch_size
|
||||
@@ -470,7 +472,8 @@ class HeliosPyramidPipeline(DiffusionPipeline, HeliosLoraLoaderMixin):
|
||||
|
||||
L = torch.linalg.cholesky(cov)
|
||||
block_number = batch_size * channel * num_frames * (height // ph) * (width // pw)
|
||||
z = torch.randn(block_number, block_size, device=device, generator=generator)
|
||||
z = torch.randn(block_number, block_size, generator=generator, device=generator.device)
|
||||
z = z.to(device=device)
|
||||
noise = z @ L.T
|
||||
|
||||
noise = noise.view(batch_size, channel, num_frames, height // ph, width // pw, ph, pw)
|
||||
|
||||
@@ -36,7 +36,7 @@ from typing import Any, Callable
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
|
||||
from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -844,6 +844,8 @@ class QuantoConfig(QuantizationConfigMixin):
|
||||
modules_to_not_convert: list[str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
deprecation_message = "`QuantoConfig` is deprecated and will be removed in version 1.0.0."
|
||||
deprecate("QuantoConfig", "1.0.0", deprecation_message)
|
||||
self.quant_method = QuantizationMethod.QUANTO
|
||||
self.weights_dtype = weights_dtype
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from diffusers.utils.import_utils import is_optimum_quanto_version
|
||||
|
||||
from ...utils import (
|
||||
deprecate,
|
||||
get_module_from_name,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
@@ -42,6 +43,9 @@ class QuantoQuantizer(DiffusersQuantizer):
|
||||
super().__init__(quantization_config, **kwargs)
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
deprecation_message = "The Quanto quantizer is deprecated and will be removed in version 1.0.0."
|
||||
deprecate("QuantoQuantizer", "1.0.0", deprecation_message)
|
||||
|
||||
if not is_optimum_quanto_available():
|
||||
raise ImportError(
|
||||
"Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
|
||||
|
||||
@@ -152,6 +152,96 @@ class FluxModularPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HeliosAutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HeliosModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HeliosPyramidAutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HeliosPyramidDistilledAutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HeliosPyramidDistilledModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class HeliosPyramidModularPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class QwenImageAutoBlocks(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
@@ -1112,6 +1202,21 @@ class EasyAnimatePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinKVPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -14,15 +14,16 @@
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import logging
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.hooks import HookRegistry, ModelHook
|
||||
from diffusers.models import ModelMixin
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.utils import get_logger
|
||||
from diffusers.utils.import_utils import compare_versions
|
||||
|
||||
from ..testing_utils import (
|
||||
@@ -218,18 +219,20 @@ class NestedContainer(torch.nn.Module):
|
||||
|
||||
|
||||
@require_torch_accelerator
|
||||
class TestGroupOffload:
|
||||
class GroupOffloadTests(unittest.TestCase):
|
||||
in_features = 64
|
||||
hidden_features = 256
|
||||
out_features = 64
|
||||
num_layers = 4
|
||||
|
||||
def setup_method(self):
|
||||
def setUp(self):
|
||||
with torch.no_grad():
|
||||
self.model = self.get_model()
|
||||
self.input = torch.randn((4, self.in_features)).to(torch_device)
|
||||
|
||||
def teardown_method(self):
|
||||
def tearDown(self):
|
||||
super().tearDown()
|
||||
|
||||
del self.model
|
||||
del self.input
|
||||
gc.collect()
|
||||
@@ -245,20 +248,18 @@ class TestGroupOffload:
|
||||
num_layers=self.num_layers,
|
||||
)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
torch.device(torch_device).type not in ["cuda", "xpu"],
|
||||
reason="Test requires a CUDA or XPU device.",
|
||||
)
|
||||
def test_offloading_forward_pass(self):
|
||||
@torch.no_grad()
|
||||
def run_forward(model):
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
backend_reset_peak_memory_stats(torch_device)
|
||||
assert all(
|
||||
module._diffusers_hook.get_hook("group_offloading") is not None
|
||||
for module in model.modules()
|
||||
if hasattr(module, "_diffusers_hook")
|
||||
self.assertTrue(
|
||||
all(
|
||||
module._diffusers_hook.get_hook("group_offloading") is not None
|
||||
for module in model.modules()
|
||||
if hasattr(module, "_diffusers_hook")
|
||||
)
|
||||
)
|
||||
model.eval()
|
||||
output = model(self.input)[0].cpu()
|
||||
@@ -290,37 +291,41 @@ class TestGroupOffload:
|
||||
output_with_group_offloading5, mem5 = run_forward(model)
|
||||
|
||||
# Precision assertions - offloading should not impact the output
|
||||
assert torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)
|
||||
assert torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)
|
||||
assert torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)
|
||||
assert torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)
|
||||
assert torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5)
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5))
|
||||
|
||||
# Memory assertions - offloading should reduce memory usage
|
||||
assert mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline
|
||||
self.assertTrue(mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline)
|
||||
|
||||
def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self, caplog):
|
||||
def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self):
|
||||
if torch.device(torch_device).type not in ["cuda", "xpu"]:
|
||||
return
|
||||
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
|
||||
with caplog.at_level(logging.WARNING, logger="diffusers.models.modeling_utils"):
|
||||
logger = get_logger("diffusers.models.modeling_utils")
|
||||
logger.setLevel("INFO")
|
||||
with self.assertLogs(logger, level="WARNING") as cm:
|
||||
self.model.to(torch_device)
|
||||
assert f"The module '{self.model.__class__.__name__}' is group offloaded" in caplog.text
|
||||
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
|
||||
|
||||
def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self, caplog):
|
||||
def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self):
|
||||
if torch.device(torch_device).type not in ["cuda", "xpu"]:
|
||||
return
|
||||
pipe = DummyPipeline(self.model)
|
||||
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
|
||||
with caplog.at_level(logging.WARNING, logger="diffusers.pipelines.pipeline_utils"):
|
||||
logger = get_logger("diffusers.pipelines.pipeline_utils")
|
||||
logger.setLevel("INFO")
|
||||
with self.assertLogs(logger, level="WARNING") as cm:
|
||||
pipe.to(torch_device)
|
||||
assert f"The module '{self.model.__class__.__name__}' is group offloaded" in caplog.text
|
||||
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
|
||||
|
||||
def test_error_raised_if_streams_used_and_no_accelerator_device(self):
|
||||
torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
|
||||
original_is_available = torch_accelerator_module.is_available
|
||||
torch_accelerator_module.is_available = lambda: False
|
||||
with pytest.raises(ValueError):
|
||||
with self.assertRaises(ValueError):
|
||||
self.model.enable_group_offload(
|
||||
onload_device=torch.device(torch_device), offload_type="leaf_level", use_stream=True
|
||||
)
|
||||
@@ -328,31 +333,31 @@ class TestGroupOffload:
|
||||
|
||||
def test_error_raised_if_supports_group_offloading_false(self):
|
||||
self.model._supports_group_offloading = False
|
||||
with pytest.raises(ValueError, match="does not support group offloading"):
|
||||
with self.assertRaisesRegex(ValueError, "does not support group offloading"):
|
||||
self.model.enable_group_offload(onload_device=torch.device(torch_device))
|
||||
|
||||
def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
|
||||
pipe = DummyPipeline(self.model)
|
||||
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
|
||||
with pytest.raises(ValueError, match="You are trying to apply model/sequential CPU offloading"):
|
||||
with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
def test_error_raised_if_sequential_offloading_applied_on_group_offloaded_module(self):
|
||||
pipe = DummyPipeline(self.model)
|
||||
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
|
||||
with pytest.raises(ValueError, match="You are trying to apply model/sequential CPU offloading"):
|
||||
with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
|
||||
def test_error_raised_if_group_offloading_applied_on_model_offloaded_module(self):
|
||||
pipe = DummyPipeline(self.model)
|
||||
pipe.enable_model_cpu_offload()
|
||||
with pytest.raises(ValueError, match="Cannot apply group offloading"):
|
||||
with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
|
||||
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
|
||||
|
||||
def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module(self):
|
||||
pipe = DummyPipeline(self.model)
|
||||
pipe.enable_sequential_cpu_offload()
|
||||
with pytest.raises(ValueError, match="Cannot apply group offloading"):
|
||||
with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
|
||||
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
|
||||
|
||||
def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
|
||||
@@ -371,12 +376,12 @@ class TestGroupOffload:
|
||||
context = contextlib.nullcontext()
|
||||
if compare_versions("diffusers", "<=", "0.33.0"):
|
||||
# Will raise a device mismatch RuntimeError mentioning weights are on CPU but input is on device
|
||||
context = pytest.raises(RuntimeError, match="Expected all tensors to be on the same device")
|
||||
context = self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device")
|
||||
|
||||
with context:
|
||||
model(self.input)
|
||||
|
||||
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
|
||||
@parameterized.expand([("block_level",), ("leaf_level",)])
|
||||
def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
|
||||
if torch.device(torch_device).type not in ["cuda", "xpu"]:
|
||||
return
|
||||
@@ -402,14 +407,14 @@ class TestGroupOffload:
|
||||
|
||||
out_ref = model_ref(x)
|
||||
out = model(x)
|
||||
assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match."
|
||||
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
|
||||
|
||||
num_repeats = 2
|
||||
for i in range(num_repeats):
|
||||
out_ref = model_ref(x)
|
||||
out = model(x)
|
||||
|
||||
assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations."
|
||||
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")
|
||||
|
||||
for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
|
||||
assert ref_name == name
|
||||
@@ -423,7 +428,9 @@ class TestGroupOffload:
|
||||
absdiff = diff.abs()
|
||||
absmax = absdiff.max().item()
|
||||
cumulated_absmax += absmax
|
||||
assert cumulated_absmax < 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
|
||||
self.assertLess(
|
||||
cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
|
||||
)
|
||||
|
||||
def test_vae_like_model_without_streams(self):
|
||||
"""Test VAE-like model with block-level offloading but without streams."""
|
||||
@@ -445,7 +452,9 @@ class TestGroupOffload:
|
||||
out_ref = model_ref(x).sample
|
||||
out = model(x).sample
|
||||
|
||||
assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
|
||||
self.assertTrue(
|
||||
torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
|
||||
)
|
||||
|
||||
def test_model_with_only_standalone_layers(self):
|
||||
"""Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading."""
|
||||
@@ -466,11 +475,12 @@ class TestGroupOffload:
|
||||
for i in range(2):
|
||||
out_ref = model_ref(x)
|
||||
out = model(x)
|
||||
assert torch.allclose(out_ref, out, atol=1e-5), (
|
||||
f"Outputs do not match at iteration {i} for model with standalone layers."
|
||||
self.assertTrue(
|
||||
torch.allclose(out_ref, out, atol=1e-5),
|
||||
f"Outputs do not match at iteration {i} for model with standalone layers.",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
|
||||
@parameterized.expand([("block_level",), ("leaf_level",)])
|
||||
def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str):
|
||||
"""Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading."""
|
||||
if torch.device(torch_device).type not in ["cuda", "xpu"]:
|
||||
@@ -491,8 +501,9 @@ class TestGroupOffload:
|
||||
out_ref = model_ref(x).sample
|
||||
out = model(x).sample
|
||||
|
||||
assert torch.allclose(out_ref, out, atol=1e-5), (
|
||||
f"Outputs do not match for standalone Conv layers with {offload_type}."
|
||||
self.assertTrue(
|
||||
torch.allclose(out_ref, out, atol=1e-5),
|
||||
f"Outputs do not match for standalone Conv layers with {offload_type}.",
|
||||
)
|
||||
|
||||
def test_multiple_invocations_with_vae_like_model(self):
|
||||
@@ -515,7 +526,7 @@ class TestGroupOffload:
|
||||
for i in range(2):
|
||||
out_ref = model_ref(x).sample
|
||||
out = model(x).sample
|
||||
assert torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}."
|
||||
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.")
|
||||
|
||||
def test_nested_container_parameters_offloading(self):
|
||||
"""Test that parameters from non-computational layers in nested containers are handled correctly."""
|
||||
@@ -536,8 +547,9 @@ class TestGroupOffload:
|
||||
for i in range(2):
|
||||
out_ref = model_ref(x)
|
||||
out = model(x)
|
||||
assert torch.allclose(out_ref, out, atol=1e-5), (
|
||||
f"Outputs do not match at iteration {i} for nested parameters."
|
||||
self.assertTrue(
|
||||
torch.allclose(out_ref, out, atol=1e-5),
|
||||
f"Outputs do not match at iteration {i} for nested parameters.",
|
||||
)
|
||||
|
||||
def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
|
||||
@@ -590,7 +602,7 @@ class DummyModelWithConditionalModules(ModelMixin):
|
||||
return x
|
||||
|
||||
|
||||
class TestConditionalModuleGroupOffload(TestGroupOffload):
|
||||
class ConditionalModuleGroupOffloadTests(GroupOffloadTests):
|
||||
"""Tests for conditionally-executed modules under group offloading with streams.
|
||||
|
||||
Regression tests for the case where a module is not executed during the first forward pass
|
||||
@@ -608,10 +620,10 @@ class TestConditionalModuleGroupOffload(TestGroupOffload):
|
||||
num_layers=self.num_layers,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("offload_type", ["leaf_level", "block_level"])
|
||||
@pytest.mark.skipif(
|
||||
@parameterized.expand([("leaf_level",), ("block_level",)])
|
||||
@unittest.skipIf(
|
||||
torch.device(torch_device).type not in ["cuda", "xpu"],
|
||||
reason="Test requires a CUDA or XPU device.",
|
||||
"Test requires a CUDA or XPU device.",
|
||||
)
|
||||
def test_conditional_modules_with_stream(self, offload_type: str):
|
||||
"""Regression test: conditionally-executed modules must not cause device mismatch when using streams.
|
||||
@@ -658,20 +670,23 @@ class TestConditionalModuleGroupOffload(TestGroupOffload):
|
||||
# execution order is traced. optional_proj_1/2 are NOT in the traced order.
|
||||
out_ref_no_opt = model_ref(x, optional_input=None)
|
||||
out_no_opt = model(x, optional_input=None)
|
||||
assert torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5), (
|
||||
f"[{offload_type}] Outputs do not match on first pass (no optional_input)."
|
||||
self.assertTrue(
|
||||
torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5),
|
||||
f"[{offload_type}] Outputs do not match on first pass (no optional_input).",
|
||||
)
|
||||
|
||||
# Second forward pass WITH optional_input — optional_proj_1/2 ARE now called.
|
||||
out_ref_with_opt = model_ref(x, optional_input=optional_input)
|
||||
out_with_opt = model(x, optional_input=optional_input)
|
||||
assert torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5), (
|
||||
f"[{offload_type}] Outputs do not match on second pass (with optional_input)."
|
||||
self.assertTrue(
|
||||
torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5),
|
||||
f"[{offload_type}] Outputs do not match on second pass (with optional_input).",
|
||||
)
|
||||
|
||||
# Third pass again without optional_input — verify stable behavior.
|
||||
out_ref_no_opt2 = model_ref(x, optional_input=None)
|
||||
out_no_opt2 = model(x, optional_input=None)
|
||||
assert torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5), (
|
||||
f"[{offload_type}] Outputs do not match on third pass (back to no optional_input)."
|
||||
self.assertTrue(
|
||||
torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5),
|
||||
f"[{offload_type}] Outputs do not match on third pass (back to no optional_input).",
|
||||
)
|
||||
|
||||
@@ -60,12 +60,7 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
model.eval()
|
||||
|
||||
# Move inputs to device
|
||||
inputs_on_device = {}
|
||||
for key, value in inputs_dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
inputs_on_device[key] = value.to(device)
|
||||
else:
|
||||
inputs_on_device[key] = value
|
||||
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
|
||||
# Enable context parallelism
|
||||
cp_config = ContextParallelConfig(**cp_dict)
|
||||
@@ -89,6 +84,59 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _custom_mesh_worker(
|
||||
rank,
|
||||
world_size,
|
||||
master_port,
|
||||
model_class,
|
||||
init_dict,
|
||||
cp_dict,
|
||||
mesh_shape,
|
||||
mesh_dim_names,
|
||||
inputs_dict,
|
||||
return_dict,
|
||||
):
|
||||
"""Worker function for context parallel testing with a user-provided custom DeviceMesh."""
|
||||
try:
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
||||
|
||||
torch.cuda.set_device(rank)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
|
||||
model = model_class(**init_dict)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
|
||||
# DeviceMesh must be created after init_process_group, inside each worker process.
|
||||
mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
|
||||
)
|
||||
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
|
||||
model.enable_parallelism(config=cp_config)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_on_device, return_dict=False)[0]
|
||||
|
||||
if rank == 0:
|
||||
return_dict["status"] = "success"
|
||||
return_dict["output_shape"] = list(output.shape)
|
||||
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
return_dict["status"] = "error"
|
||||
return_dict["error"] = str(e)
|
||||
finally:
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@is_context_parallel
|
||||
@require_torch_multi_accelerator
|
||||
class ContextParallelTesterMixin:
|
||||
@@ -126,3 +174,48 @@ class ContextParallelTesterMixin:
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cp_type,mesh_shape,mesh_dim_names",
|
||||
[
|
||||
("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
|
||||
("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
|
||||
],
|
||||
ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
|
||||
)
|
||||
def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
|
||||
if not torch.distributed.is_available():
|
||||
pytest.skip("torch.distributed is not available.")
|
||||
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
|
||||
cp_dict = {cp_type: world_size}
|
||||
|
||||
master_port = _find_free_port()
|
||||
manager = mp.Manager()
|
||||
return_dict = manager.dict()
|
||||
|
||||
mp.spawn(
|
||||
_custom_mesh_worker,
|
||||
args=(
|
||||
world_size,
|
||||
master_port,
|
||||
self.model_class,
|
||||
init_dict,
|
||||
cp_dict,
|
||||
mesh_shape,
|
||||
mesh_dim_names,
|
||||
inputs_dict,
|
||||
return_dict,
|
||||
),
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
0
tests/modular_pipelines/helios/__init__.py
Normal file
0
tests/modular_pipelines/helios/__init__.py
Normal file
166
tests/modular_pipelines/helios/test_modular_pipeline_helios.py
Normal file
166
tests/modular_pipelines/helios/test_modular_pipeline_helios.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
from diffusers.modular_pipelines import (
|
||||
HeliosAutoBlocks,
|
||||
HeliosModularPipeline,
|
||||
HeliosPyramidAutoBlocks,
|
||||
HeliosPyramidModularPipeline,
|
||||
)
|
||||
|
||||
from ..test_modular_pipelines_common import ModularPipelineTesterMixin
|
||||
|
||||
|
||||
HELIOS_WORKFLOWS = {
|
||||
"text2video": [
|
||||
("text_encoder", "HeliosTextEncoderStep"),
|
||||
("denoise.input", "HeliosTextInputStep"),
|
||||
("denoise.prepare_history", "HeliosPrepareHistoryStep"),
|
||||
("denoise.set_timesteps", "HeliosSetTimestepsStep"),
|
||||
("denoise.chunk_denoise", "HeliosChunkDenoiseStep"),
|
||||
("decode", "HeliosDecodeStep"),
|
||||
],
|
||||
"image2video": [
|
||||
("text_encoder", "HeliosTextEncoderStep"),
|
||||
("vae_encoder", "HeliosImageVaeEncoderStep"),
|
||||
("denoise.input", "HeliosTextInputStep"),
|
||||
("denoise.additional_inputs", "HeliosAdditionalInputsStep"),
|
||||
("denoise.add_noise_image", "HeliosAddNoiseToImageLatentsStep"),
|
||||
("denoise.prepare_history", "HeliosPrepareHistoryStep"),
|
||||
("denoise.seed_history", "HeliosI2VSeedHistoryStep"),
|
||||
("denoise.set_timesteps", "HeliosSetTimestepsStep"),
|
||||
("denoise.chunk_denoise", "HeliosI2VChunkDenoiseStep"),
|
||||
("decode", "HeliosDecodeStep"),
|
||||
],
|
||||
"video2video": [
|
||||
("text_encoder", "HeliosTextEncoderStep"),
|
||||
("vae_encoder", "HeliosVideoVaeEncoderStep"),
|
||||
("denoise.input", "HeliosTextInputStep"),
|
||||
("denoise.additional_inputs", "HeliosAdditionalInputsStep"),
|
||||
("denoise.add_noise_video", "HeliosAddNoiseToVideoLatentsStep"),
|
||||
("denoise.prepare_history", "HeliosPrepareHistoryStep"),
|
||||
("denoise.seed_history", "HeliosV2VSeedHistoryStep"),
|
||||
("denoise.set_timesteps", "HeliosSetTimestepsStep"),
|
||||
("denoise.chunk_denoise", "HeliosI2VChunkDenoiseStep"),
|
||||
("decode", "HeliosDecodeStep"),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class TestHeliosModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = HeliosModularPipeline
|
||||
pipeline_blocks_class = HeliosAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-helios-modular-pipe"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "num_frames"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
optional_params = frozenset(["num_inference_steps", "num_videos_per_prompt", "latents"])
|
||||
output_name = "videos"
|
||||
expected_workflow_blocks = HELIOS_WORKFLOWS
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"height": 16,
|
||||
"width": 16,
|
||||
"num_frames": 9,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
@pytest.mark.skip(reason="num_videos_per_prompt")
|
||||
def test_num_images_per_prompt(self):
|
||||
pass
|
||||
|
||||
|
||||
HELIOS_PYRAMID_WORKFLOWS = {
|
||||
"text2video": [
|
||||
("text_encoder", "HeliosTextEncoderStep"),
|
||||
("denoise.input", "HeliosTextInputStep"),
|
||||
("denoise.prepare_history", "HeliosPrepareHistoryStep"),
|
||||
("denoise.pyramid_chunk_denoise", "HeliosPyramidChunkDenoiseStep"),
|
||||
("decode", "HeliosDecodeStep"),
|
||||
],
|
||||
"image2video": [
|
||||
("text_encoder", "HeliosTextEncoderStep"),
|
||||
("vae_encoder", "HeliosImageVaeEncoderStep"),
|
||||
("denoise.input", "HeliosTextInputStep"),
|
||||
("denoise.additional_inputs", "HeliosAdditionalInputsStep"),
|
||||
("denoise.add_noise_image", "HeliosAddNoiseToImageLatentsStep"),
|
||||
("denoise.prepare_history", "HeliosPrepareHistoryStep"),
|
||||
("denoise.seed_history", "HeliosI2VSeedHistoryStep"),
|
||||
("denoise.pyramid_chunk_denoise", "HeliosPyramidI2VChunkDenoiseStep"),
|
||||
("decode", "HeliosDecodeStep"),
|
||||
],
|
||||
"video2video": [
|
||||
("text_encoder", "HeliosTextEncoderStep"),
|
||||
("vae_encoder", "HeliosVideoVaeEncoderStep"),
|
||||
("denoise.input", "HeliosTextInputStep"),
|
||||
("denoise.additional_inputs", "HeliosAdditionalInputsStep"),
|
||||
("denoise.add_noise_video", "HeliosAddNoiseToVideoLatentsStep"),
|
||||
("denoise.prepare_history", "HeliosPrepareHistoryStep"),
|
||||
("denoise.seed_history", "HeliosV2VSeedHistoryStep"),
|
||||
("denoise.pyramid_chunk_denoise", "HeliosPyramidI2VChunkDenoiseStep"),
|
||||
("decode", "HeliosDecodeStep"),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class TestHeliosPyramidModularPipelineFast(ModularPipelineTesterMixin):
|
||||
pipeline_class = HeliosPyramidModularPipeline
|
||||
pipeline_blocks_class = HeliosPyramidAutoBlocks
|
||||
pretrained_model_name_or_path = "hf-internal-testing/tiny-helios-pyramid-modular-pipe"
|
||||
|
||||
params = frozenset(["prompt", "height", "width", "num_frames"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
optional_params = frozenset(["pyramid_num_inference_steps_list", "num_videos_per_prompt", "latents"])
|
||||
output_name = "videos"
|
||||
expected_workflow_blocks = HELIOS_PYRAMID_WORKFLOWS
|
||||
|
||||
def get_dummy_inputs(self, seed=0):
|
||||
generator = self.get_generator(seed)
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"pyramid_num_inference_steps_list": [2, 2],
|
||||
"height": 64,
|
||||
"width": 64,
|
||||
"num_frames": 9,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
# Pyramid pipeline injects noise at each stage, so batch vs single can differ more
|
||||
super().test_inference_batch_single_identical(expected_max_diff=5e-1)
|
||||
|
||||
@pytest.mark.skip(reason="Pyramid multi-stage noise makes offload comparison unreliable with tiny models")
|
||||
def test_components_auto_cpu_offload_inference_consistent(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="Pyramid multi-stage noise makes save/load comparison unreliable with tiny models")
|
||||
def test_save_from_pretrained(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="num_videos_per_prompt")
|
||||
def test_num_images_per_prompt(self):
|
||||
pass
|
||||
174
tests/pipelines/flux2/test_pipeline_flux2_klein_kv.py
Normal file
174
tests/pipelines/flux2/test_pipeline_flux2_klein_kv.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLFlux2,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
Flux2KleinKVPipeline,
|
||||
Flux2Transformer2DModel,
|
||||
)
|
||||
|
||||
from ...testing_utils import torch_device
|
||||
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
|
||||
|
||||
|
||||
class Flux2KleinKVPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = Flux2KleinKVPipeline
|
||||
params = frozenset(["prompt", "height", "width", "prompt_embeds", "image"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
supports_dduf = False
|
||||
|
||||
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
|
||||
torch.manual_seed(0)
|
||||
transformer = Flux2Transformer2DModel(
|
||||
patch_size=1,
|
||||
in_channels=4,
|
||||
num_layers=num_layers,
|
||||
num_single_layers=num_single_layers,
|
||||
attention_head_dim=16,
|
||||
num_attention_heads=2,
|
||||
joint_attention_dim=16,
|
||||
timestep_guidance_channels=256,
|
||||
axes_dims_rope=[4, 4, 4, 4],
|
||||
guidance_embeds=False,
|
||||
)
|
||||
|
||||
# Create minimal Qwen3 config
|
||||
config = Qwen3Config(
|
||||
intermediate_size=16,
|
||||
hidden_size=16,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=2,
|
||||
num_key_value_heads=2,
|
||||
vocab_size=151936,
|
||||
max_position_embeddings=512,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder = Qwen3ForCausalLM(config)
|
||||
|
||||
# Use a simple tokenizer for testing
|
||||
tokenizer = Qwen2TokenizerFast.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLFlux2(
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=("DownEncoderBlock2D",),
|
||||
up_block_types=("UpDecoderBlock2D",),
|
||||
block_out_channels=(4,),
|
||||
layers_per_block=1,
|
||||
latent_channels=1,
|
||||
norm_num_groups=1,
|
||||
use_quant_conv=False,
|
||||
use_post_quant_conv=False,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "a dog is dancing",
|
||||
"image": Image.new("RGB", (64, 64)),
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"height": 8,
|
||||
"width": 8,
|
||||
"max_sequence_length": 64,
|
||||
"output_type": "np",
|
||||
"text_encoder_out_layers": (1,),
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_fused_qkv_projections(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
original_image_slice = image[0, -3:, -3:, -1]
|
||||
|
||||
pipe.transformer.fuse_qkv_projections()
|
||||
self.assertTrue(
|
||||
check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
|
||||
("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_fused = image[0, -3:, -3:, -1]
|
||||
|
||||
pipe.transformer.unfuse_qkv_projections()
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs).images
|
||||
image_slice_disabled = image[0, -3:, -3:, -1]
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
|
||||
("Fusion of QKV projections shouldn't affect the outputs."),
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
|
||||
("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
|
||||
)
|
||||
self.assertTrue(
|
||||
np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
|
||||
("Original outputs should match when fused QKV projections are disabled."),
|
||||
)
|
||||
|
||||
def test_image_output_shape(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
height_width_pairs = [(32, 32), (72, 57)]
|
||||
for height, width in height_width_pairs:
|
||||
expected_height = height - height % (pipe.vae_scale_factor * 2)
|
||||
expected_width = width - width % (pipe.vae_scale_factor * 2)
|
||||
|
||||
inputs.update({"height": height, "width": width})
|
||||
image = pipe(**inputs).images[0]
|
||||
output_height, output_width, _ = image.shape
|
||||
self.assertEqual(
|
||||
(output_height, output_width),
|
||||
(expected_height, expected_width),
|
||||
f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
|
||||
)
|
||||
|
||||
def test_without_image(self):
|
||||
device = "cpu"
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
del inputs["image"]
|
||||
image = pipe(**inputs).images
|
||||
self.assertEqual(image.shape, (1, 8, 8, 3))
|
||||
|
||||
@unittest.skip("Needs to be revisited")
|
||||
def test_encode_prompt_works_in_isolation(self):
|
||||
pass
|
||||
Reference in New Issue
Block a user